Brgemm register tiling for bf16 type #1005

Merged · 17 commits · Feb 19, 2025
Changes from 4 commits
222 changes: 109 additions & 113 deletions lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -1,4 +1,5 @@
//===- BrgemmLinalgTiling.cpp -----------------------------------------*- C++-*-===//
//===- BrgemmLinalgTiling.cpp -----------------------------------------*-
//C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -43,160 +44,152 @@ using namespace mlir::tpp;

namespace mlir {
namespace tpp {
struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
using OpRewritePattern<linalg::BatchReduceMatmulOp>::OpRewritePattern;

template <typename BrgemmOp>
struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
using OpRewritePattern<BrgemmOp>::OpRewritePattern;

LinalgOpTiling(MLIRContext *ctx, BrgemmLinalgTilingOptions tilingoptions)
: OpRewritePattern(ctx), options(tilingoptions) {}
: OpRewritePattern<BrgemmOp>(ctx), options(tilingoptions) {}

LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp brgemmOp,
LogicalResult matchAndRewrite(BrgemmOp brgemmOp,
PatternRewriter &rewriter) const override {

if (!brgemmOp.hasPureBufferSemantics())
return failure();
// Get the register blocking tile shape from the user input
SmallVector<int64_t> tileShapeM(options.registerTileShape.begin(),
options.registerTileShape.end());

if (tileShapeM.size() != 2)
return failure();
// Check whether the tile sizes are valid
if (options.registerTileShape.size() != 3 &&
options.registerTileShape.size() != 2)
return failure();

SmallVector<int64_t> tileShapeN(2);
tileShapeN[0] = 1;
tileShapeN[1] = tileShapeM[1];
tileShapeM[1] = 1;
// Check whether the operation is a batch-reduce matmul of fp32 or bf16 type
// using the reduction count
SmallVector<utils::IteratorType> brgemmIteratorTypes =
brgemmOp.getIteratorTypesArray();
int reductionCount =
std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(),
utils::IteratorType::reduction);
if (reductionCount != 2 && reductionCount != 3)
return failure();

// Stores the M, N, and K Tile Sizes
// Get the register blocking tile shape from the user input
SmallVector<int64_t> mxnxkTile(3);
// Stores the M and N Tile Sizes
SmallVector<int64_t> mxnTile(2);
for (size_t i = 0; i < options.registerTileShape.size(); i++) {
mxnxkTile[i] = options.registerTileShape[i];
}

// Set the K tile to 1 if the user did not provide it (fp32 target)
if (options.registerTileShape.size() == 2)
mxnxkTile[2] = 1;

// K-tile size is adjusted based on the VNNI layout for the bf16 type
auto tensorShape =
dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType()).getShape();
if (tensorShape.size() == 4 && options.registerTileShape.size() == 3) {
mxnxkTile[2] = mxnxkTile[2] / tensorShape[3];
}

mxnxkTile[0] = tileShapeM[0];
mxnxkTile[1] = tileShapeN[1];
mxnxkTile[2] = tileShapeM[1];
mxnTile[0] = tileShapeM[0];
mxnTile[1] = tileShapeN[1];

// To assist in calculating the argument and step values for the tiled loop.
SmallVector<int64_t> boundariesOne{1,
static_cast<long>(tileShapeM.size() - 1),
static_cast<long>(mxnxkTile.size() - 1)};

SmallVector<int64_t> tileSizesIndex{static_cast<long>(tileShapeM.size()),
static_cast<long>(tileShapeN.size()),
static_cast<long>(mxnTile.size())};
SmallVector<SmallVector<int64_t>> tileshapes{tileShapeM, tileShapeN, mxnTile};
SmallVector<int> swap_i = {0, 2, 1};
size_t i = 0;
SmallVector<int> swap_i = {0, 2, 1};
std::map<int, std::map<int, Value>> inductionVars;

// For M, N, and K loops
scf::ForOp innermostForLoop;
// For brgemm reduction loop
scf::ForOp reductionForLoop;

// Creating the tiled loops
for (auto itrShapeM = mxnxkTile.begin(); itrShapeM != mxnxkTile.end();
itrShapeM++, i++) {
int index = swap_i[i] / boundariesOne[swap_i[i]];
int offset = swap_i[i] / (mxnxkTile.size() - 1);

int operandSize =
dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
.getShape()
.size();
int effectiveOffset = operandSize - tileSizesIndex[index] + offset;
for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end();
itrShapeMNK++, i++) {
auto upperBound =
dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
.getShape()[effectiveOffset];
dyn_cast<MemRefType>(brgemmOp.getOperand(swap_i[i]).getType())
.getShape()[1];
// Tile size should not be greater than the upperBound
if ((*itrShapeMNK) > upperBound)
return failure();

Location loc = brgemmOp.getLoc();
Value zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ubCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
//Tile size should not be greater than the upperBound
if ((*itrShapeM) > upperBound)
return failure();
Value stepCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeM);
Value ubCstTiledLoop =
rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
Value stepCstTiledLoop =
rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeMNK);
// Creates M, N, and K tile loops
scf::ForOp loopOp = rewriter.create<scf::ForOp>(brgemmOp.getLoc(),
zeroCst, ubCstTiledLoop, stepCstTiledLoop);
scf::ForOp loopOp = rewriter.create<scf::ForOp>(
brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop);
rewriter.setInsertionPointToStart(loopOp.getBody());
int indexTwo = offset;
int operandSizeTwo =
dyn_cast<MemRefType>(brgemmOp.getOperand(indexTwo).getType())
.getShape()
.size();
int effectiveOffsetTwo = operandSizeTwo - tileSizesIndex[index] + index;

inductionVars[index][effectiveOffset] = loopOp.getInductionVar();

inductionVars[indexTwo][effectiveOffsetTwo] = loopOp.getInductionVar();
int indexThree = mxnTile.size();
int effectiveOffsetThree =
index +
dyn_cast<MemRefType>(brgemmOp.getOperand(indexThree).getType())
.getShape()
.size() -
tileSizesIndex[indexThree];
if (inductionVars[indexThree][effectiveOffsetThree] == NULL) {
inductionVars[indexThree][effectiveOffsetThree] =
loopOp.getInductionVar();
}

innermostForLoop = loopOp;
if ((mxnxkTile.size() - 1) == (i + 1)) {
//Creates the brgemm reduction loop

// Stores the induction variables with respect to the operands, mapping each
// to its subview.
if (i == 0) { // Stores iv for M loop
inductionVars[0][1] = loopOp.getInductionVar();
inductionVars[2][0] = loopOp.getInductionVar();
} else if (i == 1) { // Stores iv for N loop, creates batch loop, and maps iv of batch loop
inductionVars[1][2] = loopOp.getInductionVar();
inductionVars[2][1] = loopOp.getInductionVar();
// Creates reduction loop after the N loop
Value ubCstReduction = rewriter.create<arith::ConstantIndexOp>(
loc, dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType())
.getShape()[0]);
Value stepCstReduction = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value stepCstReduction =
rewriter.create<arith::ConstantIndexOp>(loc, 1);
scf::ForOp redloopOp = rewriter.create<scf::ForOp>(
brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction);
rewriter.setInsertionPointToStart(redloopOp.getBody());
reductionForLoop = redloopOp;
inductionVars[0][0] = redloopOp.getInductionVar();
inductionVars[1][0] = redloopOp.getInductionVar();

} else if (i == 2) { // Stores iv for K loop
inductionVars[0][2] = loopOp.getInductionVar();
inductionVars[1][1] = loopOp.getInductionVar();
}
}

// Data structures to assist in creating new subviews with correct indices and shapes
SmallVector<int64_t> mxkTile(2);
SmallVector<int64_t> kxnTile(2);
SmallVector<int64_t> mxnTile(2);
mxkTile[0] = mxnxkTile[0];
mxkTile[1] = mxnxkTile[2];
kxnTile[0] = mxnxkTile[2];
kxnTile[1] = mxnxkTile[1];
mxnTile[0] = mxnxkTile[0];
mxnTile[1] = mxnxkTile[1];

SmallVector<SmallVector<int64_t>> tileshapes{mxkTile, kxnTile, mxnTile};
// Creating subviews
SmallVector<SmallVector<int64_t>> tiles = {tileShapeM, tileShapeN};
for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) {
SmallVector<int64_t> indices;
auto input = brgemmOp.getOperand(i);
auto operandType = input.getType();
SmallVector<OpFoldResult> offsets;
size_t k = 0;
auto tileItr = tileshapes[i].begin();
auto tensorShape = dyn_cast<MemRefType>(operandType).getShape();
SmallVector<int64_t> indices;
SmallVector<OpFoldResult> shape;
SmallVector<OpFoldResult> strides;

auto input = brgemmOp.getOperand(i);
auto tensorShape = dyn_cast<MemRefType>(input.getType()).getShape();
auto tileItr = tileshapes[i].begin();

// Iterates over the shape of each tensor and updates its offsets, indices,
// shapes, and strides with respect to the tile sizes
for (size_t j = 0; j < tensorShape.size(); j++) {
if (j < tensorShape.size() - tileSizesIndex[i]) {
if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) &&
i < (brgemmOp.getNumOperands() - 1)) {
offsets.push_back(reductionForLoop.getInductionVar());
indices.push_back(tensorShape[j] / tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j] / tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));

} else {
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));
}
} else {
shape.push_back(rewriter.getIndexAttr(*tileItr));
if (j == 0 && (i < 2)) { // Updates the batch dimension
offsets.push_back(inductionVars[i][j]);
indices.push_back(1);
shape.push_back(rewriter.getIndexAttr(1));
strides.push_back(rewriter.getIndexAttr(1));
} else if (j < 3) { // Updates the M, N, and K dimensions
offsets.push_back(inductionVars[i][j]);
indices.push_back((*tileItr));
shape.push_back(rewriter.getIndexAttr(*tileItr));
strides.push_back(rewriter.getIndexAttr(1));
offsets.push_back(
inductionVars[i][tensorShape.size() - tileSizesIndex[i] + k]);
k++;
tileItr++;
} else { // Just copies the vnni layout dimensions
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));
}
}

auto subview = rewriter.create<memref::SubViewOp>(
brgemmOp.getLoc(), MemRefType(),
input, offsets, shape, strides);
brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides);
brgemmOp.setOperand(i, subview);
}

@@ -214,11 +207,14 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
};

void populateBrgemmLinalgTilingPatterns(RewritePatternSet &patterns,
BrgemmLinalgTilingOptions options) {
patterns.add<LinalgOpTiling>(patterns.getContext(), options);
BrgemmLinalgTilingOptions options) {
patterns.add<LinalgOpTiling<linalg::GenericOp>,
LinalgOpTiling<linalg::BatchReduceMatmulOp>>(
patterns.getContext(), options);
}

struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {
struct BrgemmLinalgTiling
: public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {

using BrgemmLinalgTilingBase::BrgemmLinalgTilingBase;

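For context, a minimal sketch of how the templated patterns above might be driven outside the BrgemmLinalgTiling pass, e.g. from a custom pass or tool. It assumes BrgemmLinalgTilingOptions exposes the registerTileShape list used in the pattern, that the TPP header path shown is correct, and that MLIR's greedy rewrite driver is used; the function name tileBrgemmRegisterBlocking is hypothetical, not part of this change.

// Usage sketch only; header path, option element type, and helper name are
// illustrative assumptions, not part of this diff.
#include "TPP/Transforms/Transforms.h" // assumed header declaring the TPP tiling API
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult tileBrgemmRegisterBlocking(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::tpp::BrgemmLinalgTilingOptions options;
  // M, N, K register tile sizes. For bf16 (VNNI-packed) inputs the pattern
  // divides the K tile by the VNNI factor, i.e. the innermost dimension of
  // the packed operand.
  options.registerTileShape = {32, 32, 32};
  mlir::tpp::populateBrgemmLinalgTilingPatterns(patterns, options);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}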
30 changes: 30 additions & 0 deletions test/Integration/tile-brgemm-linalg-matmul-bf16.mlir
@@ -0,0 +1,30 @@
// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2
// RUN: rm %t.1 %t.2

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
module {
memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
func.func @register_tile_bf16(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
%cst = arith.constant 0.000000e+00 : bf16
%0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
%expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16>
scf.forall (%arg1, %arg2) in (8, 32) {
%subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
linalg.fill ins(%cst : bf16) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>)
%subview_0 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%subview_0, %0 : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<32x16x32x2xbf16>) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>) {
^bb0(%in: bf16, %in_1: bf16, %out: bf16):
%1 = arith.mulf %in, %in_1 : bf16
%2 = arith.addf %out, %1 : bf16
linalg.yield %2 : bf16
}
}
return %alloc : memref<8x32x32x32xbf16>
}
}

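As a worked example of the K-tile adjustment this test exercises: with registerBlocking=32,32,32 on the VNNI-packed operand memref<32x32x16x2xbf16>, the requested K tile of 32 is divided by the VNNI factor 2 (the innermost dimension), so the generated K loop steps by 16 over the packed K dimension. A small stand-alone sketch of that arithmetic follows; the helper name packedKTileStep is hypothetical.

#include <cassert>
#include <cstdint>

// Mirrors `mxnxkTile[2] = mxnxkTile[2] / tensorShape[3]` in the pattern above:
// the user-facing K tile is given in unpacked elements and is converted into a
// loop step over the VNNI-packed K dimension.
static int64_t packedKTileStep(int64_t requestedKTile, int64_t vnniFactor) {
  assert(vnniFactor > 0 && requestedKTile % vnniFactor == 0 &&
         "K tile should be a multiple of the VNNI factor");
  return requestedKTile / vnniFactor;
}

// For this test: packedKTileStep(32, /*vnniFactor=*/2) == 16.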