From 4968fa7022fbdba6151822446aab165609d61863 Mon Sep 17 00:00:00 2001 From: mshahid Date: Wed, 13 Nov 2024 01:08:25 -0800 Subject: [PATCH 01/10] [MLIR][Linalg] Introduce broadcast/transpose semantics to 'linalg.batch_matmul' operation. Goals: 1. Add syntax and semantics to 'batch_matmul' without changing any of the existing syntax expectations for current usage: batch_matmul is still just batch_matmul. 2. Move the definition of batch_matmul from the Linalg OpDSL to the TableGen ODS infra. Scope of this patch: to expose broadcast and transpose semantics on 'batch_matmul'. The broadcast and transpose semantics are as follows: By default, 'linalg.batch_matmul' behavior remains as is. Broadcast and transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified. Example Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) ``` Example Broadcast: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) ``` Example Broadcast and Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>) ``` --- .../Linalg/IR/LinalgNamedStructuredOps.yaml | 69 ------ .../Dialect/Linalg/IR/LinalgStructuredOps.td | 124 ++++++++++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 218 ++++++++++++++++++ .../Linalg/Transforms/DropUnitDims.cpp | 3 +- .../linalg/opdsl/ops/core_named_ops.py | 18 -- .../Dialect/Linalg/generalize-named-ops.mlir | 24 ++ mlir/test/Dialect/Linalg/invalid.mlir | 118 ++++++++++ mlir/test/Dialect/Linalg/named-ops.mlir | 148 ++++++++++++ 8 files changed, 634 insertions(+), 88 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index b0ea1f7695581..496a323249e85 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: rhs --- !LinalgOpConfig -metadata: !LinalgOpMetadata - name: batch_matmul - cpp_class_name: BatchMatmulOp - doc: |- - Performs a batched matrix multiplication of two 3D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output.
- implements: - - LinalgContractionOpInterface -structured_op: !LinalgStructuredOpConfig - args: - - !LinalgOperandDefConfig - name: A - kind: input_tensor - type_var: T1 - shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> - - !LinalgOperandDefConfig - name: B - kind: input_tensor - type_var: T2 - shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)> - - !LinalgOperandDefConfig - name: C - kind: output_tensor - type_var: U - shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> - indexing_maps: !LinalgIndexingMapsConfig - static_indexing_maps: - - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)> - - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)> - - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)> - iterator_types: - - parallel - - parallel - - parallel - - reduction - assignments: - - !ScalarAssign - arg: C - value: !ScalarExpression - scalar_fn: - kind: binary - fn_name: add - operands: - - !ScalarExpression - scalar_arg: C - - !ScalarExpression - scalar_fn: - kind: binary - fn_name: mul - operands: - - !ScalarExpression - scalar_fn: - kind: type - fn_name: cast_signed - type_var: U - operands: - - !ScalarExpression - scalar_arg: A - - !ScalarExpression - scalar_fn: - kind: type - fn_name: cast_signed - type_var: U - operands: - - !ScalarExpression - scalar_arg: B ---- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matmul_transpose_a cpp_class_name: BatchMatmulTransposeAOp diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e3d122189f8b7..1d66cee8bd2dc 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -816,6 +816,130 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [ }]; } +//===----------------------------------------------------------------------===// +// Op definition for BatchMatmulOp +//===----------------------------------------------------------------------===// + +def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments], + /*extraInterfaces=*/[LinalgContractionOpInterface])> { + + let summary = [{Performs a batched matrix multiplication of two 3D inputs.}]; + let description = [{Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + + Broadcast and Transpose semantics can be appiled by specifying the explicit attribute + 'indexing_maps' as shown below.This is a list attribute, so the list must include all + the maps if specified. 
+ + Example Transpose: + ``` + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) + outs(%arg2: memref<2x3x7xf32>) + ``` + + Example Broadcast: + ``` + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) + outs(%arg2: memref<2x3x7xf32>) + ``` + + Example Broadcast and transpose: + ``` + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) + outs(%arg2: memref<2x3x7xf32>) + ``` +}]; + + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, + DefaultValuedOptionalAttr:$indexing_maps + ); + let results = (outs Variadic:$result_tensors); + let regions = (region AnyRegion:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder< + (ins "ValueRange":$inputs, "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs, + attributes, BatchMatmulOp::getRegionBuilder(), + BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())); + }]>, + OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + buildBatchMatmulOp($_builder, $_state, resultTensorTypes, + inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(), + BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())); + }]>, + OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(operands); + $_state.addAttributes(attributes); + $_state.addTypes(resultTensorTypes); + (void)$_state.addRegion(), + BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()); + }]> + + ]; + let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let hasVerifier = 1; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + + SmallVector getIteratorTypesArray(); + static void regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ArrayRef attrs); + static std::function)> + getRegionBuilder() { + return regionBuilder; + } + + /// Returns a list of AffineMap with the typical batch_matmul indexing charactristic. + static SmallVector getDefaultIndexingMaps(MLIRContext *context); + + /// Returns true if the given broadcast map \p bcastMap is valid for this op. + bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true); + + ::mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputsMutable(); + } + + // Generic methods. + static unsigned getNumRegionArgs(); + bool hasDynamicIndexingMaps() { return true; } + std::string getLibraryCallName(); + /// Check if the op has broadcast and/or transpose semantic. Returns true if the + /// user defined indexing maps are not equal to default map. + bool hasUserDefinedMaps(); + }]; +} + + //===----------------------------------------------------------------------===// // Named Linalg ops, implemented as a declarative configurations of generic ops. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 2f1c22b10dd36..373b09e603520 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -200,6 +200,23 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state, attributes, regionBuilder); } +static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, + std::optional resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef attributes, + RegionBuilderFn regionBuilder, + ArrayRef indexingMaps) { + // Initialize indexingMaps attribute, for BatchMatmulOp. + SmallVector indexingMapsAttrVal; + indexingMapsAttrVal = + llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + }); + state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); + return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, + attributes, regionBuilder); +} + /// Common parsing used for both named structured ops created by ods-gen and by /// manually defined C++ ops. Does not handle regions. static ParseResult @@ -3453,6 +3470,46 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, return success(); } +/// Checks if the given AffineMap represents a valid batch dimension. +/// It checks if the first result dimension is a function of the first +/// dimension. +static bool isValidBatchDim(AffineMap bcastMap) { + assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr."); + AffineExpr exp = bcastMap.getResult(0); + return exp.isFunctionOfDim(0); +} + +/// Verifies the broadcast and transpose semantic sepecified by the explicit +/// indexing map for the BatchMatmulOp \p op for each operand specified by \p +/// opIndex. +static LogicalResult +verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, + unsigned opIndex) { + SmallVector opIndexingMaps = + batchMatmulOp.getIndexingMapsArray(); + SmallVector defaultIndexingMaps = + batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext()); + + auto opIndexingMap = opIndexingMaps[opIndex]; + auto defaultIndexingMap = defaultIndexingMaps[opIndex]; + // Check general validity of indexing map results. + if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap)) + return batchMatmulOp->emitOpError() + << "Unexpected dim expression in map result."; + // Check if the requested broadcast is valid. 
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { + if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) { + return batchMatmulOp->emitOpError() << "Invalid broadcast requested."; + } + } else { + if (!isValidBatchDim(opIndexingMap)) { + return batchMatmulOp->emitOpError() + << "Invalid batch dimension expression."; + } + } + return success(); +} + namespace mlir { namespace linalg { @@ -3798,5 +3855,166 @@ Speculation::Speculatability ContractOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast(getOperation())); } +//===----------------------------------------------------------------------===// +// Implementation of BatchMatmulOp +//===----------------------------------------------------------------------===// +SmallVector +BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) { + AffineExpr d0, d1, d2, d3; + SmallVector indexingMaps; + bindDims(context, d0, d1, d2, d3); + indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context)); + indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context)); + indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context)); + return indexingMaps; +} + +SmallVector BatchMatmulOp::getIteratorTypesArray() { + return SmallVector{ + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::reduction}; +} + +unsigned BatchMatmulOp::getNumRegionArgs() { return 3; } + +std::string BatchMatmulOp::getLibraryCallName() { + return generateLibraryCallName(getOperation()); +} + +/// Check if the op has broadcast and/or transpose semantic. Returns true if +/// the user defined indexing maps are not equal to default map. +bool BatchMatmulOp::hasUserDefinedMaps() { + SmallVector defaultMaps = + getDefaultIndexingMaps(this->getContext()); + SmallVector explicitMaps = getIndexingMapsArray(); + return defaultMaps != explicitMaps; +} + +/// Returns true if the given broadcast map \p bcastMap is valid for this op. +bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) { + assert(bcastMap.getNumResults() < 3 && "Expected single result dim expr."); + bool isValid = false; + enum Indices { batchPos, mPos, nPos, kPos }; + if (bcastMap.getNumResults() == 1) { + AffineExpr exp = bcastMap.getResult(0); + isValid = exp.isFunctionOfDim(kPos); + } else if (bcastMap.getNumResults() == 2) { + AffineExpr exp0 = bcastMap.getResult(0); + AffineExpr exp1 = bcastMap.getResult(1); + isValid = isLHS + ? 
(exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos)) + : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)); + } + return isValid; +} + +void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ArrayRef attrs) { + assert(3 > 0 && block.getNumArguments() == 3 && + "BatchMatmulOp regionBuilder expects 3 (>=0) args"); + RegionBuilderHelper helper(b, block); + SmallVector yields; + + Value value1 = + helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), + block.getArgument(0)); + Value value2 = + helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), + block.getArgument(1)); + Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); + Value value4 = + helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); + yields.push_back(value4); + helper.yieldOutputs(yields); +} + +ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector indexingMapsAttr; + Attribute mapAttr; + if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) { + if (parser.parseEqual()) + return failure(); + + if (parser.parseLSquare()) + return failure(); + + do { + if (parser.parseAttribute(mapAttr)) + return failure(); + if (!isa(mapAttr)) { + return parser.emitError(parser.getCurrentLocation(), + "expected affine map attribute"); + } + indexingMapsAttr.push_back(mapAttr); + + if (parser.parseOptionalComma()) + break; + } while (true); + + if (parser.parseRSquare()) + return failure(); + } + // Initialize indexingMaps, if not supplied explicitly. + if (indexingMapsAttr.empty()) { + indexingMapsAttr = llvm::map_to_vector( + BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + } + result.addAttribute("indexing_maps", + parser.getBuilder().getArrayAttr(indexingMapsAttr)); + + return ::parseNamedStructuredOp(parser, result, + BatchMatmulOp::getNumRegionArgs(), + BatchMatmulOp::getRegionBuilder()); +} + +void BatchMatmulOp::print(OpAsmPrinter &p) { + SmallVector elidedAttrs = { + "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; + ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), + elidedAttrs); + + SmallVector indexingMaps = llvm::map_to_vector( + BatchMatmulOp::getDefaultIndexingMaps(getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + if (!llvm::equal(getIndexingMaps(), indexingMaps)) { + p << " indexing_maps = ["; + llvm::interleaveComma(getIndexingMaps(), p, + [&](Attribute attr) { p.printAttribute(attr); }); + p << "]"; + } +} + +/// Verify the user defined indexing maps. +LogicalResult BatchMatmulOp::verify() { + // Verification of pure batch_matmul is handled by + // verifyStructuredOpInterface(). 
+ if (!hasUserDefinedMaps()) + return success(); + + for (unsigned opIndex = 0; opIndex < 2; opIndex++) { + if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex))) + return failure(); + } + return success(); +} + +LogicalResult BatchMatmulOp::fold(FoldAdaptor, + SmallVectorImpl &) { + return memref::foldMemRefCast(*this); +} + +void BatchMatmulOp::getEffects( + SmallVectorImpl> + &effects) { + if (hasPureTensorSemantics()) + return; + getGenericEffectsImpl(effects, cast(getOperation())); +} + +Speculation::Speculatability BatchMatmulOp::getSpeculatability() { + return getGenericSpeculatabilityImpl(cast(getOperation())); +} + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 9b97865990bfd..a5d4c7fe9908c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -935,7 +935,8 @@ struct RankReduceContractionOps : OpRewritePattern { loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs}, ValueRange{collapsedInit}); for (auto attr : contractionOp->getAttrs()) { - if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) + if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName || + attr.getName() == "indexing_maps") continue; collapsedOp->setAttr(attr.getName(), attr.getValue()); } diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index c95cd5eecfffc..040663c882a08 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -484,24 +484,6 @@ def batch_mmt4d( ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0]) -@linalg_structured_op -def batch_matmul( - A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True), -): - """Performs a batched matrix multiplication of two 3D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.k, D.n] - ) - - @linalg_structured_op def batch_matmul_transpose_a( A=TensorDef(T1, Batch, S.K, S.M), diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir index a3611b8e4ec62..38ac230a2dee3 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -999,6 +999,30 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7 // ----- +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK-LABEL: func.func @batch_matmul( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, +// CHECK-SAME: %[[VAL_2:.*]]: tensor) -> tensor { +// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor, tensor) outs(%[[VAL_2]] : tensor) { +// CHECK: arith.mulf +// CHECK: arith.addf + +func.func @batch_matmul(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1: tensor, tensor) + outs(%arg2: tensor) -> tensor + return %0 : tensor +} + +// ----- + // CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> // CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 0ea805fef5361..9b94d6aaf053e 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1258,3 +1258,121 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32> %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32> return %0 : tensor<2x12x11x2xf32> } + +// ----- + +func.func @missing_indexing_map_batch_matmul(%arg0: tensor, %arg1: tensor, %arg2: tensor) { + // expected-error @+1 {{expected attribute value}} + linalg.batch_matmul indexing_maps = [ + , + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 :memref) + return +} + +// ----- + +func.func @invalid_dim_expr_batch_matmul_a(%arg0: tensor, %arg1: tensor, %arg2: tensor) { + // expected-error @+1 {{Unexpected dim expression in map result}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 :tensor) + return +} + +// ----- + +func.func @invalid_dim_expr_batch_matmul_b(%arg0: tensor, %arg1: tensor, %arg2: tensor) { + // expected-error @+1 {{Unexpected dim expression in map result}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, 
d2)> + ] + ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 :tensor) + return +} + +// ----- + +func.func @invalid_bcast_batch_matmul_a(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref, memref) outs(%arg2: memref) + return +} + +// ----- + +func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref, memref) outs(%arg2: memref) + return +} + +// ----- + +func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d3, d0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref, memref) outs(%arg2: memref) + return +} + +// ----- + +func.func @invalid_bcast_batch_matmul_b(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref, memref) outs(%arg2: memref) + return +} + +// ----- + +func.func @invalid_batch_dim_batch_matmul_a(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref, memref) outs(%arg2 :memref) + return +} + +// ----- + +func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref, memref) outs(%arg2 :memref) + return +} diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 578d24a550b08..59b7d1926c0b8 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1485,6 +1485,154 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a // ----- +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_m_dim_A( +// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { +// CHECK: linalg.batch_matmul 
ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: return +// CHECK: } +func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK-LABEL: func.func @batch_matmul_bcast_batch_dim_A( +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { +// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: return +// CHECK: } +func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_n_dim_B( +// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { +// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: return +// CHECK: } +func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) { + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<2x3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK-LABEL: func.func @batch_matmul_bcast_batch_dim_B( +// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { +// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: return +// CHECK: } + +func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<2x3x7xf32>) { + linalg.batch_matmul 
indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<2x3x7xf32>) + return +} + +// ----- + +// CHECK-LABEL: func @batch_matmul_explicit_transpose_a +// CHECK: linalg.batch_matmul +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>) +func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) + return +} + +// ----- + +// CHECK-LABEL: func @batch_matmul_explicit_transpose_b +// CHECK: linalg.batch_matmul +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>) +func.func @batch_matmul_explicit_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) { + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK-LABEL: func.func @batch_matmul_bcast_A_transpose_B( +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<2x7x5xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { +// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: return +// CHECK: } +func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) { + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>) + return +} + +// ----- + // CHECK-LABEL: func @batchmatmul_transpose_a // CHECK: linalg.batch_matmul_transpose_a // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>) From 80ea697a0a9e2242686b25d7c00cb6bdc0fbe560 Mon Sep 17 00:00:00 2001 From: mshahid Date: Fri, 10 Jan 2025 10:02:47 -0800 Subject: [PATCH 02/10] -Added output map verification and corresponding tests. -Replaced assert for the count of number of dim expression with proper error reporting and new test case. -Fixed typos. 
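For illustration (a sketch modeled on the new invalid.mlir cases in this patch; the operand names and shapes are placeholders), an output indexing map that does not cover all three (batch, m, n) result dimensions is now rejected with a proper diagnostic instead of tripping an assert:

```
linalg.batch_matmul indexing_maps = [
                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                     affine_map<(d0, d1, d2, d3) -> (d1, d2)> // invalid: only two result dims
                    ]
                    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
                    outs(%arg2 : memref<?x?x?xf32>)
```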
--- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 53 ++++++++++++++----- mlir/test/Dialect/Linalg/invalid.mlir | 66 ++++++++++++++++++++++-- 2 files changed, 101 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 373b09e603520..470502b37eca0 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3479,7 +3479,18 @@ static bool isValidBatchDim(AffineMap bcastMap) { return exp.isFunctionOfDim(0); } -/// Verifies the broadcast and transpose semantic sepecified by the explicit +/// Checks if the given AffineMap's result dimensions are valid output result +/// dimensions. +static bool isValidOutputResultDim(AffineMap outputMap) { + enum Indices { batchPos, mPos, nPos }; + AffineExpr exp0 = outputMap.getResult(batchPos); + AffineExpr exp1 = outputMap.getResult(mPos); + AffineExpr exp2 = outputMap.getResult(nPos); + return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) && + exp2.isFunctionOfDim(nPos); +} + +/// Verifies the broadcast and transpose semantic specified by the explicit /// indexing map for the BatchMatmulOp \p op for each operand specified by \p /// opIndex. static LogicalResult @@ -3493,19 +3504,35 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, auto opIndexingMap = opIndexingMaps[opIndex]; auto defaultIndexingMap = defaultIndexingMaps[opIndex]; // Check general validity of indexing map results. - if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap)) - return batchMatmulOp->emitOpError() - << "Unexpected dim expression in map result."; - // Check if the requested broadcast is valid. - if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { - if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) { - return batchMatmulOp->emitOpError() << "Invalid broadcast requested."; + if (opIndex < 2) { + if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap)) + return batchMatmulOp->emitOpError() + << "Unexpected dim expression in map result."; + // Check if the requested broadcast is valid. + if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { + if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, + opIndex == 0)) { + return batchMatmulOp->emitOpError() << "Invalid broadcast requested."; + } + } else { + // Check for valid number of result dims of input maps. + if (opIndexingMap.getNumResults() != 3) + return batchMatmulOp->emitOpError() + << "no. of result dim expression cannot exceed 3."; + + if (!isValidBatchDim(opIndexingMap)) + return batchMatmulOp->emitOpError() + << "Invalid batch dimension expression."; } } else { - if (!isValidBatchDim(opIndexingMap)) { + // Check for valid number of result dims of output map. + if (opIndexingMap.getNumResults() != 3) return batchMatmulOp->emitOpError() - << "Invalid batch dimension expression."; - } + << "no. 
of result dim expression cannot exceed 3."; + + if (!isValidOutputResultDim(opIndexingMap)) + return batchMatmulOp->emitOpError() + << "Invalid output map result dimension."; } return success(); } @@ -3910,7 +3937,7 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) { void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs) { - assert(3 > 0 && block.getNumArguments() == 3 && + assert(block.getNumArguments() == 3 && "BatchMatmulOp regionBuilder expects 3 (>=0) args"); RegionBuilderHelper helper(b, block); SmallVector yields; @@ -3992,7 +4019,7 @@ LogicalResult BatchMatmulOp::verify() { if (!hasUserDefinedMaps()) return success(); - for (unsigned opIndex = 0; opIndex < 2; opIndex++) { + for (unsigned opIndex = 0; opIndex < 3; opIndex++) { if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex))) return failure(); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 9b94d6aaf053e..b2eac5a19aaaa 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1261,7 +1261,7 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32> // ----- -func.func @missing_indexing_map_batch_matmul(%arg0: tensor, %arg1: tensor, %arg2: tensor) { +func.func @missing_indexing_map_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { // expected-error @+1 {{expected attribute value}} linalg.batch_matmul indexing_maps = [ , @@ -1275,27 +1275,27 @@ func.func @missing_indexing_map_batch_matmul(%arg0: tensor, %arg1: te // ----- -func.func @invalid_dim_expr_batch_matmul_a(%arg0: tensor, %arg1: tensor, %arg2: tensor) { +func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref, %arg1: memref, %arg2: memref) { // expected-error @+1 {{Unexpected dim expression in map result}} linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] - ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 :tensor) + ins(%arg0, %arg1 : memref, memref) outs(%arg2 :memref) return } // ----- -func.func @invalid_dim_expr_batch_matmul_b(%arg0: tensor, %arg1: tensor, %arg2: tensor) { +func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref, %arg1: memref, %arg2: memref) { // expected-error @+1 {{Unexpected dim expression in map result}} linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] - ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 :tensor) + ins(%arg0, %arg1 : memref, memref) outs(%arg2 :memref) return } @@ -1376,3 +1376,59 @@ func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref, %arg1: mem ins(%arg0, %arg1 : memref, memref) outs(%arg2 :memref) return } + +// ----- + +func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1: memref, memref) + outs(%arg2: memref) + return +} + +// ----- + +func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op no. 
of result dim expression cannot exceed 3.}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1: memref, memref) + outs(%arg2: memref) + return +} + +// ----- + +func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%arg0, %arg1: memref, memref) + outs(%arg2: memref) + return +} + +// ----- + +func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{'linalg.batch_matmul' op Invalid output map result dimension.}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> + ] + ins(%arg0, %arg1: memref, memref) + outs(%arg2: memref) + return +} From 0a16982b6a630fc45f174b69954e888428bc16c7 Mon Sep 17 00:00:00 2001 From: mshahid Date: Wed, 15 Jan 2025 01:05:07 -0800 Subject: [PATCH 03/10] *Added checks for extended semantics and exit gracefully in user passes. *Added and udated test cases. *Refactored verification logic. --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 80 ++++++++++++------- .../Linalg/Transforms/BlockPackMatmul.cpp | 10 +++ .../Linalg/Transforms/DropUnitDims.cpp | 9 +++ .../Linalg/Transforms/TransposeMatmul.cpp | 8 ++ .../Dialect/Linalg/generalize-named-ops.mlir | 15 ++-- mlir/test/Dialect/Linalg/invalid.mlir | 2 +- 6 files changed, 85 insertions(+), 39 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 470502b37eca0..c9eae30e4fb85 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3474,7 +3474,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, /// It checks if the first result dimension is a function of the first /// dimension. static bool isValidBatchDim(AffineMap bcastMap) { - assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr."); AffineExpr exp = bcastMap.getResult(0); return exp.isFunctionOfDim(0); } @@ -3490,6 +3489,48 @@ static bool isValidOutputResultDim(AffineMap outputMap) { exp2.isFunctionOfDim(nPos); } +// Check general validity of input indexing map. +static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, + AffineMap opIndexingMap, + AffineMap defaultIndexingMap, bool isLHS) { + // Check the result dims are valid. + if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap)) + return batchMatmulOp->emitOpError() + << "Unexpected dim expression in map result."; + + // Check for valid number of result dims of input maps. + if (opIndexingMap.getNumResults() > 3) + return batchMatmulOp->emitOpError() + << "no. of result dim expression cannot exceed 3."; + + // Check if the requested broadcast is valid. 
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { + if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS)) + return batchMatmulOp->emitOpError() << "Invalid broadcast requested."; + } else if (!isValidBatchDim(opIndexingMap)) { + return batchMatmulOp->emitOpError() + << "Invalid batch dimension expression."; + } + return success(); +} + +/// This function checks if the given AffineMap for the output of a +/// BatchMatmulOp has exactly 3 result dimensions and if the output map result +/// dimensions are valid. +static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, + AffineMap opIndexingMap) { + if (opIndexingMap.getNumResults() != 3) + return batchMatmulOp->emitOpError() + << "expects 3 dims, but got (" << opIndexingMap.getNumResults() + << ")."; + + if (!isValidOutputResultDim(opIndexingMap)) + return batchMatmulOp->emitOpError() + << "Invalid output map result dimension."; + + return success(); +} + /// Verifies the broadcast and transpose semantic specified by the explicit /// indexing map for the BatchMatmulOp \p op for each operand specified by \p /// opIndex. @@ -3503,37 +3544,14 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, auto opIndexingMap = opIndexingMaps[opIndex]; auto defaultIndexingMap = defaultIndexingMaps[opIndex]; - // Check general validity of indexing map results. - if (opIndex < 2) { - if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap)) - return batchMatmulOp->emitOpError() - << "Unexpected dim expression in map result."; - // Check if the requested broadcast is valid. - if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { - if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, - opIndex == 0)) { - return batchMatmulOp->emitOpError() << "Invalid broadcast requested."; - } - } else { - // Check for valid number of result dims of input maps. - if (opIndexingMap.getNumResults() != 3) - return batchMatmulOp->emitOpError() - << "no. of result dim expression cannot exceed 3."; - - if (!isValidBatchDim(opIndexingMap)) - return batchMatmulOp->emitOpError() - << "Invalid batch dimension expression."; - } - } else { - // Check for valid number of result dims of output map. - if (opIndexingMap.getNumResults() != 3) - return batchMatmulOp->emitOpError() - << "no. of result dim expression cannot exceed 3."; - if (!isValidOutputResultDim(opIndexingMap)) - return batchMatmulOp->emitOpError() - << "Invalid output map result dimension."; - } + if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap))) + return failure(); + + if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap, + opIndex == 0))) + return failure(); + return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index ed1685a9cb9e6..7f9a0f7a6ca43 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -138,6 +138,16 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, FailureOr linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul) { + // Check to not let go the batch_matmul with extended semantic, through this + // transform. 
+ if (auto *batchMatmulOp = dyn_cast(&linalgOp)) { + if (batchMatmulOp->hasUserDefinedMaps()) { + return rewriter.notifyMatchFailure( + *batchMatmulOp, + "only batch_matmul ops with non-extended semantics are supported"); + } + } + if (linalgOp.hasPureBufferSemantics()) return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index a5d4c7fe9908c..a5ebe7628accd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -906,6 +906,15 @@ struct RankReduceContractionOps : OpRewritePattern { LogicalResult matchAndRewrite(FromOpTy contractionOp, PatternRewriter &rewriter) const override { + // Check to not let go the batch_matmul with extended semantic, through this + // transform. + if (std::is_same::value) { + if (contractionOp.hasUserDefinedMaps()) { + return rewriter.notifyMatchFailure( + contractionOp, + "only batch_matmul ops with non-extended semantics are supported"); + } + } auto loc = contractionOp.getLoc(); auto inputs = contractionOp.getDpsInputs(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp index 6b934f7e8157d..8d12f8a98dbdd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp @@ -88,6 +88,14 @@ FailureOr mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp, bool transposeLHS) { + // Check to not let go the batch_matmul with extended semantic, through this + // transform. + if (batchMatmulOp.hasUserDefinedMaps()) { + return rewriter.notifyMatchFailure( + batchMatmulOp, + "only batch_matmul ops with non-extended semantics are supported"); + } + if (!bufferization::hasTensorSemantics(batchMatmulOp)) return rewriter.notifyMatchFailure( batchMatmulOp, "only matmul ops with tensors are supported"); diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir index 38ac230a2dee3..0ec71c35497b1 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -1004,21 +1004,22 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> // CHECK-LABEL: func.func @batch_matmul( -// CHECK-SAME: %[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, -// CHECK-SAME: %[[VAL_2:.*]]: tensor) -> tensor { -// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor, tensor) outs(%[[VAL_2]] : tensor) { +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x5x7xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> { +// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<2x3x7xf32>) { // CHECK: arith.mulf // CHECK: arith.addf -func.func @batch_matmul(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<2x3x7xf32>) -> 
tensor<2x3x7xf32> { %0 = linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] - ins(%arg0, %arg1: tensor, tensor) - outs(%arg2: tensor) -> tensor - return %0 : tensor + ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>) + outs(%arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> + return %0 : tensor<2x3x7xf32> } // ----- diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index b2eac5a19aaaa..208052c479f4d 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1408,7 +1408,7 @@ func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref, %arg1 // ----- func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { - // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}} + // expected-error @+1 {{'linalg.batch_matmul' op expects 3 dims, but got (2).}} linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, From 1f44f8f95657f0f0f7595988ef95d214d8aaf9b9 Mon Sep 17 00:00:00 2001 From: mshahid Date: Wed, 15 Jan 2025 03:20:03 -0800 Subject: [PATCH 04/10] *Added logic and tests to verify the size of supplied indexing_map attribute. --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 ++++ mlir/test/Dialect/Linalg/invalid.mlir | 27 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index c9eae30e4fb85..84db4eaf57623 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3542,6 +3542,10 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, SmallVector defaultIndexingMaps = batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext()); + if (opIndexingMaps.size() != 3) + return batchMatmulOp->emitOpError() + << "Indexing_map attribute must have 3 affine maps."; + auto opIndexingMap = opIndexingMaps[opIndex]; auto defaultIndexingMap = defaultIndexingMaps[opIndex]; diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 208052c479f4d..1b8f442e78cf8 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1261,6 +1261,33 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32> // ----- +func.func @indexing_map_size_mismatch_batch_matmul(%arg0: memref, + %arg1: memref, %arg2: memref) { + // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> + ] + ins(%arg0, %arg1 : memref, memref) + outs(%arg2: memref) + return +} + +// ----- + +func.func @indexing_map_size_one_batch_matmul(%arg0: memref, + %arg1: memref, %arg2: memref) { + // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}} + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> + ] + ins(%arg0, %arg1 : memref, memref) + outs(%arg2: memref) + return +} + +// ----- + func.func @missing_indexing_map_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { // expected-error @+1 {{expected attribute value}} linalg.batch_matmul indexing_maps = [ From 318eecab82ba706eea903691e42041807b7b005f Mon Sep 17 00:00:00 2001 From: mshahid 
Date: Thu, 16 Jan 2025 06:48:49 -0800 Subject: [PATCH 05/10] *Added logic to update the indexing_map attribute for collapsed MatmulOp. *Updated test names and comments for consistency. --- .../Dialect/Linalg/IR/LinalgStructuredOps.td | 10 ++++---- .../Linalg/Transforms/DropUnitDims.cpp | 24 ++++++++++++++----- .../Linalg/Transforms/TransposeMatmul.cpp | 5 +--- mlir/test/Dialect/Linalg/named-ops.mlir | 8 +++---- 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 1d66cee8bd2dc..4888d30fb7909 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -828,8 +828,8 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz them to the same data type as the accumulator/output. Broadcast and Transpose semantics can be appiled by specifying the explicit attribute - 'indexing_maps' as shown below.This is a list attribute, so the list must include all - the maps if specified. + 'indexing_maps' as shown below. This is a list attribute, so must include maps for all + arguments if specified. Example Transpose: ``` @@ -845,7 +845,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz Example Broadcast: ``` linalg.batch_matmul indexing_maps = [ - affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast + affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] @@ -853,7 +853,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz outs(%arg2: memref<2x3x7xf32>) ``` - Example Broadcast and transpose: + Example Broadcast and Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast @@ -919,7 +919,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz return regionBuilder; } - /// Returns a list of AffineMap with the typical batch_matmul indexing charactristic. + /// Returns a list with default AffineMap(s), i.e. without broadcasts and transpositions. static SmallVector getDefaultIndexingMaps(MLIRContext *context); /// Returns true if the given broadcast map \p bcastMap is valid for this op. diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index a5ebe7628accd..904ad220d5551 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -32,6 +32,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include namespace mlir { #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS @@ -908,11 +909,11 @@ struct RankReduceContractionOps : OpRewritePattern { PatternRewriter &rewriter) const override { // Check to not let go the batch_matmul with extended semantic, through this // transform. 
- if (std::is_same::value) { + if (std::is_same::value || + std::is_same::value) { if (contractionOp.hasUserDefinedMaps()) { return rewriter.notifyMatchFailure( - contractionOp, - "only batch_matmul ops with non-extended semantics are supported"); + contractionOp, "ops with user-defined maps are not supported"); } } @@ -944,10 +945,21 @@ struct RankReduceContractionOps : OpRewritePattern { loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs}, ValueRange{collapsedInit}); for (auto attr : contractionOp->getAttrs()) { - if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName || - attr.getName() == "indexing_maps") + if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) continue; - collapsedOp->setAttr(attr.getName(), attr.getValue()); + + // Update the indexing_maps attribute for the collapsed MatmulOp. + if (attr.getName() == "indexing_maps" && + std::is_same::value && + std::is_same::value) { + SmallVector indexingMapsAttr = llvm::map_to_vector( + MatmulOp::getDefaultIndexingMaps(rewriter.getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + collapsedOp->setAttr(attr.getName(), + rewriter.getArrayAttr(indexingMapsAttr)); + } else { + collapsedOp->setAttr(attr.getName(), attr.getValue()); + } } auto results = contractionOp.getResults(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp index 8d12f8a98dbdd..e624f589917d1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp @@ -88,12 +88,9 @@ FailureOr mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp, bool transposeLHS) { - // Check to not let go the batch_matmul with extended semantic, through this - // transform. 
if (batchMatmulOp.hasUserDefinedMaps()) { return rewriter.notifyMatchFailure( - batchMatmulOp, - "only batch_matmul ops with non-extended semantics are supported"); + batchMatmulOp, "ops with user-defined maps are not supported"); } if (!bufferization::hasTensorSemantics(batchMatmulOp)) diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 59b7d1926c0b8..86230e47d537e 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1578,11 +1578,11 @@ func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memre // ----- -// CHECK-LABEL: func @batch_matmul_explicit_transpose_a +// CHECK-LABEL: func @batch_matmul_explicit_transpose_A // CHECK: linalg.batch_matmul // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>) // CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>) -func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { +func.func @batch_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, @@ -1594,11 +1594,11 @@ func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: me // ----- -// CHECK-LABEL: func @batch_matmul_explicit_transpose_b +// CHECK-LABEL: func @batch_matmul_explicit_transpose_B // CHECK: linalg.batch_matmul // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>) // CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>) -func.func @batch_matmul_explicit_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) { +func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) { linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, From f69611650f17eaa4248639566be96d62c4f4e7be Mon Sep 17 00:00:00 2001 From: mshahid Date: Tue, 21 Jan 2025 23:03:05 -0800 Subject: [PATCH 06/10] *Added logic to ensure the indexing_map attribute can be dropped for collapsed contraction op. *Refactored some tests and methods for better naming, comments and readability. --- .../Dialect/Linalg/IR/LinalgInterfaces.td | 5 +- .../Dialect/Linalg/IR/LinalgStructuredOps.td | 6 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 55 ++++++++----------- .../Linalg/Transforms/DropUnitDims.cpp | 28 ++-------- mlir/test/Dialect/Linalg/invalid.mlir | 8 +-- mlir/test/Dialect/Linalg/named-ops.mlir | 4 +- 6 files changed, 42 insertions(+), 64 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 244db23925ab3..98a5fd278a997 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -710,7 +710,10 @@ def LinalgStructuredInterface >, InterfaceMethod< /*desc=*/[{ - Return true if the user has supplied an explicit indexing maps for this op. + Returns true if the user has supplied explicit indexing maps that are + different from default indexing maps for this op. Returns `false` otherwise. + Note, if the user define maps that are identical to the default maps, + this method returns `false`. 
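(Aside, illustrative only: the contract documented above means "user-defined" is decided by comparing the op's maps against its defaults, not by the mere presence of the indexing_maps attribute. A sketch of a check with that shape, using the getDefaultIndexingMaps/getIndexingMapsArray accessors this series relies on and mirroring the defaultMaps != explicitMaps comparison visible in the LinalgOps.cpp hunks below:)

```
#include "mlir/Dialect/Linalg/IR/Linalg.h"
using namespace mlir;

// Sketch: maps that merely restate the defaults are not "user-defined", so
// transforms that bail out on hasUserDefinedMaps() still accept such ops.
static bool mapsDifferFromDefaults(linalg::BatchMatmulOp op) {
  SmallVector<AffineMap> defaults =
      linalg::BatchMatmulOp::getDefaultIndexingMaps(op.getContext());
  SmallVector<AffineMap> explicitMaps = op.getIndexingMapsArray();
  return defaults != explicitMaps; // equal to defaults => returns false
}
```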
}], /*retTy=*/"bool", /*methodName=*/"hasUserDefinedMaps", diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 4888d30fb7909..110ed7d2fc00e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -674,8 +674,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ static unsigned getNumRegionArgs(); std::string getLibraryCallName(); bool hasDynamicIndexingMaps(); - /// Check if the op has broadcast and/or transpose semantic. Returns true if the - /// user defined indexing maps are not equal to default map. + /// Returns true if the user defined indexing maps are not equal to default maps. bool hasUserDefinedMaps(); }]; } @@ -933,8 +932,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz static unsigned getNumRegionArgs(); bool hasDynamicIndexingMaps() { return true; } std::string getLibraryCallName(); - /// Check if the op has broadcast and/or transpose semantic. Returns true if the - /// user defined indexing maps are not equal to default map. + /// Returns true if the user defined indexing maps are not equal to default maps. bool hasUserDefinedMaps(); }]; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 84db4eaf57623..deee68e5f6828 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3426,11 +3426,10 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder, return arith::ConstantOp::materialize(builder, value, type, loc); } -/// Returns true if the result AffineExpr of the \p explicitMap is same as \p -/// defaultMap. -static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) { - auto explicitRange = explictMap.getResults(); - auto defaultRange = defaultMap.getResults(); +// Returns true if the result expression of `subMap` are a subset of `fullMap`. +static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) { + auto explicitRange = subMap.getResults(); + auto defaultRange = fullMap.getResults(); DenseSet explicitSet(explicitRange.begin(), explicitRange.end()); DenseSet defaultSet(defaultRange.begin(), defaultRange.end()); llvm::set_union(explicitSet, defaultSet); @@ -3455,7 +3454,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, auto opIndexingMap = opIndexingMaps[opIndex]; auto defaultIndexingMap = defaultIndexingMaps[opIndex]; // Check general validity of indexing map results. - if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap)) + if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap)) return matmulOp->emitOpError() << "Unexpected dim expression in map result."; @@ -3470,44 +3469,31 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, return success(); } -/// Checks if the given AffineMap represents a valid batch dimension. -/// It checks if the first result dimension is a function of the first -/// dimension. -static bool isValidBatchDim(AffineMap bcastMap) { - AffineExpr exp = bcastMap.getResult(0); - return exp.isFunctionOfDim(0); -} - -/// Checks if the given AffineMap's result dimensions are valid output result -/// dimensions. 
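(Aside, illustrative only: areResultExprsSubsetOf above accepts any operand map whose result expressions are drawn from the corresponding default map, which is what admits broadcast maps with fewer results. A small worked example with hand-built maps; the local names are hypothetical:)

```
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
using namespace mlir;

// For batch_matmul's default LHS map (d0, d1, d3): the broadcast map (d3) is
// a subset of the default results and passes, while (d0, d2, d3) is rejected
// because d2 is not a default LHS result dimension.
static void subsetCheckExample() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx), d1 = getAffineDimExpr(1, &ctx),
             d2 = getAffineDimExpr(2, &ctx), d3 = getAffineDimExpr(3, &ctx);
  AffineMap defaultLhs = AffineMap::get(4, 0, {d0, d1, d3}, &ctx);
  AffineMap bcastLhs = AffineMap::get(4, 0, {d3}, &ctx);        // accepted
  AffineMap badLhs = AffineMap::get(4, 0, {d0, d2, d3}, &ctx);  // rejected
  (void)defaultLhs; (void)bcastLhs; (void)badLhs;
}
```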
-static bool isValidOutputResultDim(AffineMap outputMap) { - enum Indices { batchPos, mPos, nPos }; - AffineExpr exp0 = outputMap.getResult(batchPos); - AffineExpr exp1 = outputMap.getResult(mPos); - AffineExpr exp2 = outputMap.getResult(nPos); - return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) && - exp2.isFunctionOfDim(nPos); -} - // Check general validity of input indexing map. static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS) { // Check the result dims are valid. - if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap)) + if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap)) return batchMatmulOp->emitOpError() - << "Unexpected dim expression in map result."; + << "Unexpected result dim expression (outside the set of default " + "result dims)."; // Check for valid number of result dims of input maps. if (opIndexingMap.getNumResults() > 3) return batchMatmulOp->emitOpError() - << "no. of result dim expression cannot exceed 3."; + << "no. of result dim expressions exceeds 3."; + + auto hasValidBatchDim = [](AffineMap map) { + AffineExpr batchDim = map.getResult(0); + return batchDim.isFunctionOfDim(0); + }; // Check if the requested broadcast is valid. if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS)) return batchMatmulOp->emitOpError() << "Invalid broadcast requested."; - } else if (!isValidBatchDim(opIndexingMap)) { + } else if (!hasValidBatchDim(opIndexingMap)) { return batchMatmulOp->emitOpError() << "Invalid batch dimension expression."; } @@ -3524,7 +3510,13 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, << "expects 3 dims, but got (" << opIndexingMap.getNumResults() << ")."; - if (!isValidOutputResultDim(opIndexingMap)) + auto areValidOutputResultDim = [](AffineMap outputMap) { + return outputMap.getResult(0).isFunctionOfDim(0) && + outputMap.getResult(1).isFunctionOfDim(1) && + outputMap.getResult(2).isFunctionOfDim(2); + }; + + if (!areValidOutputResultDim(opIndexingMap)) return batchMatmulOp->emitOpError() << "Invalid output map result dimension."; @@ -3941,7 +3933,8 @@ bool BatchMatmulOp::hasUserDefinedMaps() { /// Returns true if the given broadcast map \p bcastMap is valid for this op. bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) { - assert(bcastMap.getNumResults() < 3 && "Expected single result dim expr."); + assert(bcastMap.getNumResults() < 3 && + "Expected less than 3 result dim expr."); bool isValid = false; enum Indices { batchPos, mPos, nPos, kPos }; if (bcastMap.getNumResults() == 1) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 904ad220d5551..efea4dea66d2e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -907,14 +907,9 @@ struct RankReduceContractionOps : OpRewritePattern { LogicalResult matchAndRewrite(FromOpTy contractionOp, PatternRewriter &rewriter) const override { - // Check to not let go the batch_matmul with extended semantic, through this - // transform. 
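(Aside, illustrative only: the hasValidBatchDim and areValidOutputResultDim lambdas above pin the batch result of an input map to d0 and the output results to (d0, d1, d2). A few hand-built maps showing what passes and what is rejected; the variable names are hypothetical:)

```
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
using namespace mlir;

// Output map (d0, d1, d2) is accepted; (d0, d2, d1) is rejected because
// result 1 must involve d1 and result 2 must involve d2. For inputs, an RHS
// map such as (d3, d0, d2) passes the subset check but fails the batch-dim
// check, since its first result does not involve d0.
static void batchMatmulMapCheckExamples() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx), d1 = getAffineDimExpr(1, &ctx),
             d2 = getAffineDimExpr(2, &ctx), d3 = getAffineDimExpr(3, &ctx);
  AffineMap validOut = AffineMap::get(4, 0, {d0, d1, d2}, &ctx);    // accepted
  AffineMap badOut = AffineMap::get(4, 0, {d0, d2, d1}, &ctx);      // rejected
  AffineMap badRhsBatch = AffineMap::get(4, 0, {d3, d0, d2}, &ctx); // rejected
  (void)validOut; (void)badOut; (void)badRhsBatch;
}
```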
- if (std::is_same::value || - std::is_same::value) { - if (contractionOp.hasUserDefinedMaps()) { - return rewriter.notifyMatchFailure( - contractionOp, "ops with user-defined maps are not supported"); - } + if (contractionOp.hasUserDefinedMaps()) { + return rewriter.notifyMatchFailure( + contractionOp, "ops with user-defined maps are not supported"); } auto loc = contractionOp.getLoc(); @@ -945,21 +940,10 @@ struct RankReduceContractionOps : OpRewritePattern { loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs}, ValueRange{collapsedInit}); for (auto attr : contractionOp->getAttrs()) { - if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) + if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName || + attr.getName() == "indexing_maps") continue; - - // Update the indexing_maps attribute for the collapsed MatmulOp. - if (attr.getName() == "indexing_maps" && - std::is_same::value && - std::is_same::value) { - SmallVector indexingMapsAttr = llvm::map_to_vector( - MatmulOp::getDefaultIndexingMaps(rewriter.getContext()), - [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); - collapsedOp->setAttr(attr.getName(), - rewriter.getArrayAttr(indexingMapsAttr)); - } else { - collapsedOp->setAttr(attr.getName(), attr.getValue()); - } + collapsedOp->setAttr(attr.getName(), attr.getValue()); } auto results = contractionOp.getResults(); diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 1b8f442e78cf8..cff741e75077e 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1303,7 +1303,7 @@ func.func @missing_indexing_map_batch_matmul(%arg0: memref, %arg1: me // ----- func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref, %arg1: memref, %arg2: memref) { - // expected-error @+1 {{Unexpected dim expression in map result}} + // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}} linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, @@ -1316,7 +1316,7 @@ func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref, %arg1: memr // ----- func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref, %arg1: memref, %arg2: memref) { - // expected-error @+1 {{Unexpected dim expression in map result}} + // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}} linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, @@ -1407,7 +1407,7 @@ func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref, %arg1: mem // ----- func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { - // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}} + // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}} linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, @@ -1421,7 +1421,7 @@ func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref, %arg1 // ----- func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { - // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}} + // expected-error @+1 {{'linalg.batch_matmul' op no. 
of result dim expressions exceeds 3.}} linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>, diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 86230e47d537e..ed8683522c74a 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1489,14 +1489,14 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_m_dim_A( +// CHECK-LABEL: func.func @batch_matmul_bcast_k_to_fill_missing_dims_A( // CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { // CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] // CHECK: return // CHECK: } -func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { +func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, From 8c329f4c1b37ae3e0720e1c88f5cdb25a63a965c Mon Sep 17 00:00:00 2001 From: mshahid Date: Wed, 22 Jan 2025 07:59:42 -0800 Subject: [PATCH 07/10] *Added tests to check DropUnitDim transform is not being applied on contraction Op having user defined indexing_maps. 
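(Aside, not part of the patch: the negative tests added below rely on the rank-reducing patterns simply failing to match ops that carry non-default indexing_maps. A sketch of how such patterns are typically driven in a test pass; the populate helper name and the greedy-driver call are assumptions based on common MLIR conventions, not lines taken from this series:)

```
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

// Sketch: run the contraction rank-reducing patterns over a region. Ops whose
// indexing_maps differ from the defaults fail to match (the
// hasUserDefinedMaps() guard above) and are left untouched.
static void runRankReduction(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateContractionOpRankReducingPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```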
--- .../Linalg/rank-reduce-contraction-ops.mlir | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir index ebdbe70ff46eb..c68a6362f52c5 100644 --- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir +++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir @@ -35,6 +35,23 @@ func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memr // ----- +func.func @negative_singleton_batch_matmul_to_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) { + // CHECK-LABEL: @negative_singleton_batch_matmul_to_matmul_memref + // CHECK-NOT: collapse_shape + // CHECK-NOT: linalg.matmul + // CHECK-NOT: expand_shape + linalg.batch_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + ] + ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>) + outs(%arg2 : memref<1x?x?xf32>) + return +} + +// ----- + func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> { // CHECK-LABEL: @singleton_batch_matvec // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32> @@ -135,6 +152,20 @@ func.func @matmul_to_matvec(%arg0: memref, %arg1: memref, %arg // ----- +func.func @negative_matmul_to_matvec(%arg0: memref, %arg1: memref, %arg2: memref) { + // CHECK-LABEL: @negative_matmul_to_matvec + // CHECK-NOT: linalg.matvec + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1: memref, memref) outs(%arg2: memref) + return +} + +// ----- + func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> { // CHECK-LABEL: @matmul_to_vecmat // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32> From ec0e5024cbd6632a89f70d7f2d50bb7e8c1f4779 Mon Sep 17 00:00:00 2001 From: mshahid Date: Thu, 23 Jan 2025 02:24:06 -0800 Subject: [PATCH 08/10] *Updated code comment related to broadcast map check. --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index deee68e5f6828..47062d60aaed1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3458,7 +3458,10 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, return matmulOp->emitOpError() << "Unexpected dim expression in map result."; - // Check if the requested broadcast is valid. + // Check if the user defined map is valid broadcast map. Here broadcast + // indexing maps are defined in context of corresponding default indexing maps + // for the given Op. This way the check becomes very simple i.e just check the + // number of result dims. 
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) { return matmulOp->emitOpError() From 237e041b37c424e585cf73e55286e44b913c522c Mon Sep 17 00:00:00 2001 From: mshahid Date: Thu, 23 Jan 2025 04:43:15 -0800 Subject: [PATCH 09/10] *Removed undesired header from DropUnitDims.cpp --- mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index efea4dea66d2e..bd4ffabfbb929 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -32,7 +32,6 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include namespace mlir { #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS From 048a48106bb3a7af15920c938028b8f776576036 Mon Sep 17 00:00:00 2001 From: mshahid Date: Thu, 6 Feb 2025 05:48:21 -0800 Subject: [PATCH 10/10] *Renames few variables and updates few comments --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 36 +++++++++++------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 47062d60aaed1..b50931f15826c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3436,8 +3436,12 @@ static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) { return explicitSet == defaultSet; } -/// Returns true if the \p explictMap is broadcasted with respect to the -/// \p defaultMap. +/// Check if the user defined map is valid broadcast map. Here broadcast +/// indexing maps are defined in context of corresponding default indexing maps +/// for the given Op. This way the check becomes very simple i.e just check the +/// number of result dims. +/// Returns true if the explictMap is broadcasted with respect to the +/// defaultMap. static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) { return explictMap.getNumResults() < defaultMap.getNumResults(); } @@ -3458,10 +3462,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, return matmulOp->emitOpError() << "Unexpected dim expression in map result."; - // Check if the user defined map is valid broadcast map. Here broadcast - // indexing maps are defined in context of corresponding default indexing maps - // for the given Op. This way the check becomes very simple i.e just check the - // number of result dims. if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) { return matmulOp->emitOpError() @@ -3527,8 +3527,7 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, } /// Verifies the broadcast and transpose semantic specified by the explicit -/// indexing map for the BatchMatmulOp \p op for each operand specified by \p -/// opIndex. +/// indexing map for the BatchMatmulOp op for each operand specified by opIndex. static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex) { @@ -3934,7 +3933,7 @@ bool BatchMatmulOp::hasUserDefinedMaps() { return defaultMaps != explicitMaps; } -/// Returns true if the given broadcast map \p bcastMap is valid for this op. +/// Returns true if the given broadcast map bcastMap is valid for this op. 
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) { assert(bcastMap.getNumResults() < 3 && "Expected less than 3 result dim expr."); @@ -3960,16 +3959,15 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, RegionBuilderHelper helper(b, block); SmallVector yields; - Value value1 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(0)); - Value value2 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(1)); - Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); - Value value4 = - helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); - yields.push_back(value4); + auto toType = block.getArgument(2).getType(); + Value castValA = + helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0)); + Value castValB = + helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1)); + Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB); + Value addVal = + helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal); + yields.push_back(addVal); helper.yieldOutputs(yields); }
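(Closing aside, illustrative only: after the renames above, the region built for batch_matmul reads as cast, multiply, accumulate. In plain C++ terms the default semantics are equivalent to the reference loop nest below, where cast_signed is approximated by static_cast and the b/m/n/k names are chosen purely for illustration:)

```
// Reference semantics of the default linalg.batch_matmul region:
// C[b][m][n] += cast(A[b][m][k]) * cast(B[b][k][n]), accumulating over k.
template <typename TA, typename TB, typename TC>
void batchMatmulReference(const TA *A, const TB *B, TC *C, int batch, int M,
                          int N, int K) {
  for (int b = 0; b < batch; ++b)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < K; ++k)
          C[(b * M + m) * N + n] +=
              static_cast<TC>(A[(b * M + m) * K + k]) *
              static_cast<TC>(B[(b * K + k) * N + n]);
}
```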