diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 47062d60aaed1..b50931f15826c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3436,8 +3436,12 @@ static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) { return explicitSet == defaultSet; } -/// Returns true if the \p explictMap is broadcasted with respect to the -/// \p defaultMap. +/// Check if the user defined map is valid broadcast map. Here broadcast +/// indexing maps are defined in context of corresponding default indexing maps +/// for the given Op. This way the check becomes very simple i.e just check the +/// number of result dims. +/// Returns true if the explictMap is broadcasted with respect to the +/// defaultMap. static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) { return explictMap.getNumResults() < defaultMap.getNumResults(); } @@ -3458,10 +3462,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, return matmulOp->emitOpError() << "Unexpected dim expression in map result."; - // Check if the user defined map is valid broadcast map. Here broadcast - // indexing maps are defined in context of corresponding default indexing maps - // for the given Op. This way the check becomes very simple i.e just check the - // number of result dims. if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) { return matmulOp->emitOpError() @@ -3527,8 +3527,7 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, } /// Verifies the broadcast and transpose semantic specified by the explicit -/// indexing map for the BatchMatmulOp \p op for each operand specified by \p -/// opIndex. +/// indexing map for the BatchMatmulOp op for each operand specified by opIndex. static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex) { @@ -3934,7 +3933,7 @@ bool BatchMatmulOp::hasUserDefinedMaps() { return defaultMaps != explicitMaps; } -/// Returns true if the given broadcast map \p bcastMap is valid for this op. +/// Returns true if the given broadcast map bcastMap is valid for this op. bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) { assert(bcastMap.getNumResults() < 3 && "Expected less than 3 result dim expr."); @@ -3960,16 +3959,15 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, RegionBuilderHelper helper(b, block); SmallVector yields; - Value value1 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(0)); - Value value2 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(1)); - Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); - Value value4 = - helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); - yields.push_back(value4); + auto toType = block.getArgument(2).getType(); + Value castValA = + helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0)); + Value castValB = + helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1)); + Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB); + Value addVal = + helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal); + yields.push_back(addVal); helper.yieldOutputs(yields); }