diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 244db23925ab3..98a5fd278a997 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -710,7 +710,10 @@ def LinalgStructuredInterface
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return true if the user has supplied an explicit indexing maps for this op.
+        Returns true if the user has supplied explicit indexing maps that are
+        different from the default indexing maps for this op. Returns `false`
+        otherwise. Note that if the user defines maps that are identical to
+        the default maps, this method also returns `false`.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasUserDefinedMaps",
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b0ea1f7695581..496a323249e85 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
       - !ScalarExpression
         scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul
-  cpp_class_name: BatchMatmulOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul_transpose_a
   cpp_class_name: BatchMatmulTransposeAOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e3d122189f8b7..110ed7d2fc00e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -674,8 +674,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       static unsigned getNumRegionArgs();
       std::string getLibraryCallName();
       bool hasDynamicIndexingMaps();
-      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
-      /// user defined indexing maps are not equal to default map.
+      /// Returns true if the user defined indexing maps are not equal to the default maps.
      bool hasUserDefinedMaps();
    }];
}
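Editor's note: to make the `hasUserDefinedMaps` contract above concrete, here is a hedged sketch (the helper below is hypothetical and not part of this patch): explicit maps that merely restate the defaults do not count as user defined.

```
// Hypothetical illustration, not part of this patch: maps equal to the
// defaults are not treated as user defined.
bool restatesDefaults(mlir::linalg::BatchMatmulOp op) {
  llvm::SmallVector<mlir::AffineMap> defaults =
      mlir::linalg::BatchMatmulOp::getDefaultIndexingMaps(op.getContext());
  // When this holds, op.hasUserDefinedMaps() returns false.
  return op.getIndexingMapsArray() == defaults;
}
```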
@@ -816,6 +815,129 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul",
+    !listconcat([AttrSizedOperandSegments],
+    /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+
+  let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and transpose semantics can be applied by specifying the explicit
+    attribute 'indexing_maps' as shown below. This is a list attribute, so it
+    must include maps for all arguments if specified.
+
+    Example Transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
+                   outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+                   outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                   affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+                   outs(%arg2: memref<2x3x7xf32>)
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+    (ins "ValueRange":$inputs, "ValueRange":$outputs,
+          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+    [{
+      buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+                    attributes, BatchMatmulOp::getRegionBuilder(),
+                    BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+    }]>,
+    OpBuilder<
+    (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+          "ValueRange":$outputs,
+          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+    [{
+      buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+                    inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+                    BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+    }]>,
+    OpBuilder<
+    (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+    [{
+      $_state.addOperands(operands);
+      $_state.addAttributes(attributes);
+      $_state.addTypes(resultTensorTypes);
+      (void)$_state.addRegion();
+    }]>
+
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    /// Returns a list with default AffineMap(s), i.e. without broadcasts and
+    /// transpositions.
+    static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    static unsigned getNumRegionArgs();
+    bool hasDynamicIndexingMaps() { return true; }
+    std::string getLibraryCallName();
+    /// Returns true if the user defined indexing maps are not equal to the
+    /// default maps.
+    bool hasUserDefinedMaps();
+  }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
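Editor's note: a usage-level sketch of the builders declared above (the names `b`, `loc`, `lhs`, `rhs`, and `init` are assumed, not from this patch). When `indexing_maps` is not supplied, the builder attaches the default batch-matmul maps.

```
// Sketch, assuming an OpBuilder `b`, a Location `loc`, and Values
// lhs/rhs/init: the first builder fills in the default indexing maps.
auto op = b.create<mlir::linalg::BatchMatmulOp>(
    loc, mlir::ValueRange{lhs, rhs}, mlir::ValueRange{init});
```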
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2f1c22b10dd36..b50931f15826c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -200,6 +200,23 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
+                               std::optional<TypeRange> resultTensorTypes,
+                               ValueRange inputs, ValueRange outputs,
+                               ArrayRef<NamedAttribute> attributes,
+                               RegionBuilderFn regionBuilder,
+                               ArrayRef<AffineMap> indexingMaps) {
+  // Initialize indexingMaps attribute, for BatchMatmulOp.
+  SmallVector<Attribute> indexingMapsAttrVal;
+  indexingMapsAttrVal =
+      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+        return AffineMapAttr::get(map);
+      });
+  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3409,19 +3426,22 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
-/// Returns true if the result AffineExpr of the \p explicitMap is same as \p
-/// defaultMap.
-static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
-  auto explicitRange = explictMap.getResults();
-  auto defaultRange = defaultMap.getResults();
+// Returns true if the result expressions of `subMap` are a subset of those of
+// `fullMap`.
+static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
+  auto explicitRange = subMap.getResults();
+  auto defaultRange = fullMap.getResults();
   DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
   DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
   llvm::set_union(explicitSet, defaultSet);
   return explicitSet == defaultSet;
 }
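Editor's note: a quick illustration of the subset check above (the function name is from the hunk; the example itself is only a sketch and assumes an `MLIRContext`).

```
// Illustrative only: with the default LHS map (d0, d1, d3), the broadcast
// map (d1, d3) passes the subset check, while (d0, d2, d3) is rejected
// because d2 is not a default LHS result dim.
void subsetCheckExample(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0, d1, d2, d3;
  mlir::bindDims(ctx, d0, d1, d2, d3);
  mlir::AffineMap defaultLhs = mlir::AffineMap::get(4, 0, {d0, d1, d3}, ctx);
  assert(areResultExprsSubsetOf(mlir::AffineMap::get(4, 0, {d1, d3}, ctx),
                                defaultLhs));
  assert(!areResultExprsSubsetOf(mlir::AffineMap::get(4, 0, {d0, d2, d3}, ctx),
                                 defaultLhs));
}
```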
-/// Returns true if the \p explictMap is broadcasted with respect to the
-/// \p defaultMap.
+/// Check if the user defined map is a valid broadcast map. Here broadcast
+/// indexing maps are defined in the context of the corresponding default
+/// indexing maps for the given op. This keeps the check simple: just compare
+/// the number of result dims.
+/// Returns true if `explicitMap` is a broadcast with respect to `defaultMap`.
 static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
   return explicitMap.getNumResults() < defaultMap.getNumResults();
 }
@@ -3438,11 +3458,10 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
     auto opIndexingMap = opIndexingMaps[opIndex];
     auto defaultIndexingMap = defaultIndexingMaps[opIndex];
     // Check general validity of indexing map results.
-    if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+    if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
       return matmulOp->emitOpError()
              << "Unexpected dim expression in map result.";
 
-    // Check if the requested broadcast is valid.
     if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
       if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
         return matmulOp->emitOpError()
@@ -3453,6 +3472,87 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
+// Check general validity of input indexing map.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap,
+                                     AffineMap defaultIndexingMap, bool isLHS) {
+  // Check the result dims are valid.
+  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Unexpected result dim expression (outside the set of default "
+              "result dims).";
+
+  // Check for valid number of result dims of input maps.
+  if (opIndexingMap.getNumResults() > 3)
+    return batchMatmulOp->emitOpError()
+           << "no. of result dim expressions exceeds 3.";
+
+  auto hasValidBatchDim = [](AffineMap map) {
+    AffineExpr batchDim = map.getResult(0);
+    return batchDim.isFunctionOfDim(0);
+  };
+
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+  } else if (!hasValidBatchDim(opIndexingMap)) {
+    return batchMatmulOp->emitOpError()
+           << "Invalid batch dimension expression.";
+  }
+  return success();
+}
+
+/// This function checks if the given AffineMap for the output of a
+/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
+/// dimensions are valid.
+static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap) {
+  if (opIndexingMap.getNumResults() != 3)
+    return batchMatmulOp->emitOpError()
+           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
+           << ").";
+
+  auto areValidOutputResultDim = [](AffineMap outputMap) {
+    return outputMap.getResult(0).isFunctionOfDim(0) &&
+           outputMap.getResult(1).isFunctionOfDim(1) &&
+           outputMap.getResult(2).isFunctionOfDim(2);
+  };
+
+  if (!areValidOutputResultDim(opIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Invalid output map result dimension.";
+
+  return success();
+}
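Editor's note: to illustrate the output-map rule in `verifyOutputMap` above (a sketch; the helper name and values are hypothetical): result `i` of the output map must be a function of dimension `di`.

```
// Illustrative sketch of the output-map rule: (d0, d1, d2) is the only
// accepted shape; e.g. swapping M and N, (d0, d2, d1), is rejected.
void outputMapRuleExample(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0, d1, d2, d3;
  mlir::bindDims(ctx, d0, d1, d2, d3);
  mlir::AffineMap ok = mlir::AffineMap::get(4, 0, {d0, d1, d2}, ctx);
  mlir::AffineMap swapped = mlir::AffineMap::get(4, 0, {d0, d2, d1}, ctx);
  assert(ok.getResult(1).isFunctionOfDim(1));       // accepted
  assert(!swapped.getResult(1).isFunctionOfDim(1)); // rejected by the check
}
```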
+/// Verifies the broadcast and transpose semantics specified by the explicit
+/// indexing maps for the BatchMatmulOp for each operand specified by opIndex.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+                                  unsigned opIndex) {
+  SmallVector<AffineMap> opIndexingMaps =
+      batchMatmulOp.getIndexingMapsArray();
+  SmallVector<AffineMap> defaultIndexingMaps =
+      batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+
+  if (opIndexingMaps.size() != 3)
+    return batchMatmulOp->emitOpError()
+           << "Indexing_map attribute must have 3 affine maps.";
+
+  auto opIndexingMap = opIndexingMaps[opIndex];
+  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+
+  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+    return failure();
+
+  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
+                             opIndex == 0)))
+    return failure();
+
+  return success();
+}
+
 namespace mlir {
 namespace linalg {
 
@@ -3798,5 +3898,166 @@ Speculation::Speculatability ContractOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  AffineExpr d0, d1, d2, d3;
+  SmallVector<AffineMap> indexingMaps;
+  bindDims(context, d0, d1, d2, d3);
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
+  return indexingMaps;
+}
+
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{
+      utils::IteratorType::parallel, utils::IteratorType::parallel,
+      utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchMatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantics. Returns true if
+/// the user defined indexing maps are not equal to the default maps.
+bool BatchMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap> defaultMaps =
+      getDefaultIndexingMaps(this->getContext());
+  SmallVector<AffineMap> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
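Editor's note: before the implementation of `isValidLhsRhsBroadcastMap` below, a hedged usage sketch of the rule it enforces (the op value is assumed): with the iteration space (batch, M, N, K) = (d0, d1, d2, d3), broadcasts may only drop leading dims.

```
// Sketch: for the LHS, a 1-result map must be (d3) (K only) and a 2-result
// map must be (d1, d3) (M, K); for the RHS, (d3) or (d3, d2) (K, N).
void broadcastMapExamples(mlir::linalg::BatchMatmulOp op) {
  mlir::MLIRContext *ctx = op.getContext();
  mlir::AffineExpr d0, d1, d2, d3;
  mlir::bindDims(ctx, d0, d1, d2, d3);
  assert(op.isValidLhsRhsBroadcastMap(mlir::AffineMap::get(4, 0, {d3}, ctx),
                                      /*isLHS=*/true));
  assert(op.isValidLhsRhsBroadcastMap(
      mlir::AffineMap::get(4, 0, {d3, d2}, ctx), /*isLHS=*/false));
}
```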
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+  assert(bcastMap.getNumResults() < 3 &&
+         "Expected less than 3 result dim expr.");
+  bool isValid = false;
+  enum Indices { batchPos, mPos, nPos, kPos };
+  if (bcastMap.getNumResults() == 1) {
+    AffineExpr exp = bcastMap.getResult(0);
+    isValid = exp.isFunctionOfDim(kPos);
+  } else if (bcastMap.getNumResults() == 2) {
+    AffineExpr exp0 = bcastMap.getResult(0);
+    AffineExpr exp1 = bcastMap.getResult(1);
+    isValid = isLHS
+                  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+                  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+  }
+  return isValid;
+}
+
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  auto toType = block.getArgument(2).getType();
+  Value castValA =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+  Value castValB =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+  Value addVal =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+  yields.push_back(addVal);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<Attribute> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = llvm::map_to_vector(
+        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
+  return ::parseNamedStructuredOp(parser, result,
+                                  BatchMatmulOp::getNumRegionArgs(),
+                                  BatchMatmulOp::getRegionBuilder());
+}
+
+void BatchMatmulOp::print(OpAsmPrinter &p) {
+  SmallVector<StringRef> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+
+  SmallVector<Attribute> indexingMaps = llvm::map_to_vector(
+      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+}
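Editor's note: a usage-level sketch of the printer's elision behavior above, assuming the standard `parseSourceString` entry point from mlir/Parser/Parser.h. Maps equal to the defaults are not printed, so a plain op round-trips without the attribute.

```
#include "mlir/Parser/Parser.h"

// Sketch: indexing_maps equal to the defaults are elided when printing.
void roundTripExample(mlir::MLIRContext *ctx) {
  const char *ir = R"mlir(
    func.func @rt(%a: memref<2x3x5xf32>, %b: memref<2x5x7xf32>,
                  %c: memref<2x3x7xf32>) {
      linalg.batch_matmul ins(%a, %b : memref<2x3x5xf32>, memref<2x5x7xf32>)
                          outs(%c : memref<2x3x7xf32>)
      return
    }
  )mlir";
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(ir, ctx);
  // Printing `module` re-emits the op without an indexing_maps attribute.
}
```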
+/// Verify the user defined indexing maps.
+LogicalResult BatchMatmulOp::verify() {
+  // Verification of pure batch_matmul is handled by
+  // verifyStructuredOpInterface().
+  if (!hasUserDefinedMaps())
+    return success();
+
+  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
+    if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+      return failure();
+  }
+  return success();
+}
+
+LogicalResult BatchMatmulOp::fold(FoldAdaptor,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void BatchMatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index ed1685a9cb9e6..7f9a0f7a6ca43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -138,6 +138,16 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
 FailureOr<PackResult>
 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                         const ControlBlockPackMatmulFn &controlPackMatmul) {
+  // Do not let batch_matmul ops with extended semantics (user-defined maps)
+  // through this transform.
+  if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
+    if (batchMatmulOp->hasUserDefinedMaps()) {
+      return rewriter.notifyMatchFailure(
+          *batchMatmulOp,
+          "only batch_matmul ops with non-extended semantics are supported");
+    }
+  }
+
   if (linalgOp.hasPureBufferSemantics())
     return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9b97865990bfd..bd4ffabfbb929 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -906,6 +906,10 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
 
   LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                 PatternRewriter &rewriter) const override {
+    if (contractionOp.hasUserDefinedMaps()) {
+      return rewriter.notifyMatchFailure(
+          contractionOp, "ops with user-defined maps are not supported");
+    }
     auto loc = contractionOp.getLoc();
     auto inputs = contractionOp.getDpsInputs();
 
@@ -935,7 +939,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
         loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
         ValueRange{collapsedInit});
     for (auto attr : contractionOp->getAttrs()) {
-      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
+          attr.getName() == "indexing_maps")
         continue;
       collapsedOp->setAttr(attr.getName(), attr.getValue());
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 6b934f7e8157d..e624f589917d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -88,6 +88,11 @@ FailureOr<Operation *>
 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                    linalg::BatchMatmulOp batchMatmulOp,
                                    bool transposeLHS) {
+  if (batchMatmulOp.hasUserDefinedMaps()) {
+    return rewriter.notifyMatchFailure(
+        batchMatmulOp, "ops with user-defined maps are not supported");
+  }
+
   if (!bufferization::hasTensorSemantics(batchMatmulOp))
     return rewriter.notifyMatchFailure(
         batchMatmulOp, "only matmul ops with tensors are supported");
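Editor's note: the three transform guards above follow one pattern; here is a generic sketch (the free function is hypothetical) of how a rewrite that assumes the canonical maps should bail out early.

```
// Hypothetical rewrite-pattern guard mirroring the checks added above.
mlir::LogicalResult guardedMatch(mlir::linalg::BatchMatmulOp op,
                                 mlir::PatternRewriter &rewriter) {
  if (op.hasUserDefinedMaps())
    return rewriter.notifyMatchFailure(
        op, "ops with user-defined maps are not supported");
  // ... the actual rewrite would go here ...
  return mlir::success();
}
```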
a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffc..040663c882a08 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -484,24 +484,6 @@ def batch_mmt4d(
     ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
 
 
-@linalg_structured_op
-def batch_matmul(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
-    """Performs a batched matrix multiplication of two 3D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n]
-    )
-
-
 @linalg_structured_op
 def batch_matmul_transpose_a(
     A=TensorDef(T1, Batch, S.K, S.M),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index a3611b8e4ec62..0ec71c35497b1 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -999,6 +999,31 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
+// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<2x3x7xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
+  %0 = linalg.batch_matmul indexing_maps = [
+                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                             affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                           ]
+                           ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+                           outs(%arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32>
+  return %0 : tensor<2x3x7xf32>
+}
+
+// -----
+
 // CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
 // CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 0ea805fef5361..cff741e75077e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1258,3 +1258,204 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
   %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
   return %0 : tensor<2x12x11x2xf32>
 }
+
+// -----
+
+func.func @indexing_map_size_mismatch_batch_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.batch_matmul indexing_maps = [
+                       ,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?x?xf32>)
+                     outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op expects 3 dims, but got (2).}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid output map result dimension.}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+                     ]
+                     ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%arg2: memref<?x?x?xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 578d24a550b08..ed8683522c74a 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1485,6 +1485,154 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_batch_dim_A(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_batch_dim_B(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_A
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_B
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_A_transpose_B(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @batchmatmul_transpose_a
 // CHECK: linalg.batch_matmul_transpose_a
 // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index ebdbe70ff46eb..c68a6362f52c5 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -35,6 +35,23 @@ func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memr
 
 // -----
 
+func.func @negative_singleton_batch_matmul_to_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @negative_singleton_batch_matmul_to_matmul_memref
+  // CHECK-NOT: collapse_shape
+  // CHECK-NOT: linalg.matmul
+  // CHECK-NOT: expand_shape
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+                     outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
 func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
   // CHECK-LABEL: @singleton_batch_matvec
   // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
@@ -135,6 +152,20 @@ func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg
 
 // -----
 
+func.func @negative_matmul_to_matvec(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  // CHECK-LABEL: @negative_matmul_to_matvec
+  // CHECK-NOT: linalg.matvec
+  linalg.matmul indexing_maps = [
+                 affine_map<(d0, d1, d2) -> (d2)>,
+                 affine_map<(d0, d1, d2) -> (d2, d1)>,
+                 affine_map<(d0, d1, d2) -> (d0, d1)>
+               ]
+               ins(%arg0, %arg1: memref<?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
 func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
   // CHECK-LABEL: @matmul_to_vecmat
   // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>