Skip to content

Commit

Permalink
*Renames few variables and updates few comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shahidact committed Feb 6, 2025
1 parent 237e041 commit 048a481
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3436,8 +3436,12 @@ static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
return explicitSet == defaultSet;
}

/// Returns true if the \p explictMap is broadcasted with respect to the
/// \p defaultMap.
/// Check if the user defined map is valid broadcast map. Here broadcast
/// indexing maps are defined in context of corresponding default indexing maps
/// for the given Op. This way the check becomes very simple i.e just check the
/// number of result dims.
/// Returns true if the explictMap is broadcasted with respect to the
/// defaultMap.
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
return explictMap.getNumResults() < defaultMap.getNumResults();
}
Expand All @@ -3458,10 +3462,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return matmulOp->emitOpError()
<< "Unexpected dim expression in map result.";

// Check if the user defined map is valid broadcast map. Here broadcast
// indexing maps are defined in context of corresponding default indexing maps
// for the given Op. This way the check becomes very simple i.e just check the
// number of result dims.
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
return matmulOp->emitOpError()
Expand Down Expand Up @@ -3527,8 +3527,7 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
}

/// Verifies the broadcast and transpose semantic specified by the explicit
/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
/// opIndex.
/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
static LogicalResult
verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
unsigned opIndex) {
Expand Down Expand Up @@ -3934,7 +3933,7 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
return defaultMaps != explicitMaps;
}

/// Returns true if the given broadcast map \p bcastMap is valid for this op.
/// Returns true if the given broadcast map bcastMap is valid for this op.
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
assert(bcastMap.getNumResults() < 3 &&
"Expected less than 3 result dim expr.");
Expand All @@ -3960,16 +3959,15 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
RegionBuilderHelper helper(b, block);
SmallVector<Value> yields;

Value value1 =
helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
block.getArgument(0));
Value value2 =
helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
block.getArgument(1));
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
Value value4 =
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
yields.push_back(value4);
auto toType = block.getArgument(2).getType();
Value castValA =
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
Value castValB =
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
Value addVal =
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
yields.push_back(addVal);
helper.yieldOutputs(yields);
}

Expand Down

0 comments on commit 048a481

Please sign in to comment.