
[MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… #122275

Merged
merged 10 commits into llvm:main on Feb 6, 2025

Conversation

@shahidact (Contributor) commented Jan 9, 2025

…h_matmul' operation.

Goals:

  1. To add syntax and semantics to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul.

  2. Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra.

Scope of this patch:
To expose broadcast and transpose semantics on the 'batch_matmul'.

The broadcast and transpose semantics are as follows:

By default, 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified.

Example Transpose:
```
linalg.batch_matmul indexing_maps = [
                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                    ]
                    ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
                    outs(%arg2 : memref<2x3x7xf32>)
```

Example Broadcast:
```
linalg.batch_matmul indexing_maps = [
                      affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                    ]
                    ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
                    outs(%arg2 : memref<2x3x7xf32>)
```

Example Broadcast and transpose:
```
linalg.batch_matmul indexing_maps = [
                      affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
                      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                    ]
                    ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
                    outs(%arg2 : memref<2x3x7xf32>)
```
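
For completeness: when 'indexing_maps' is omitted, the op keeps its existing behavior, which is equivalent to spelling out the default maps explicitly. A sketch (operand names and shapes are illustrative, chosen to match the examples above):
```
linalg.batch_matmul indexing_maps = [
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, // A: (batch, m, k)
                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, // B: (batch, k, n)
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>  // C: (batch, m, n)
                    ]
                    ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
                    outs(%arg2 : memref<2x3x7xf32>)
```
Since these are the default maps, the custom printer elides the attribute, so this round-trips to the plain `linalg.batch_matmul ins(...) outs(...)` form.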

RFCs and related PR:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
#115319

@llvmbot (Member) commented Jan 9, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Md Asghar Ahmad Shahid (shahidact)

Patch is 35.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122275.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-69)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+124)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+217)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+2-1)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-18)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+23)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+118)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+148)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b0ea1f76955816..496a323249e852 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul
-  cpp_class_name: BatchMatmulOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul_transpose_a
   cpp_class_name: BatchMatmulTransposeAOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..47b871aa322309 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,130 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
+  /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+    
+  let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+  let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+    'indexing_maps' as shown below.This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+                   outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast and transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+}];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, BatchMatmulOp::getRegionBuilder(),
+          BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+          BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion(),
+        BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
+      }]>
+      
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      /// Returns a list of AffineMap with the typical batch_matmul indexing charactristic.
+      static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      bool hasDynamicIndexingMaps() { return true; }
+      std::string getLibraryCallName();
+      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+      /// user defined indexing maps are not equal to default map.
+      bool hasUserDefinedMaps();
+    }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..868892d1e5f5cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,23 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
+                               std::optional<TypeRange> resultTensorTypes,
+                               ValueRange inputs, ValueRange outputs,
+                               ArrayRef<NamedAttribute> attributes,
+                               RegionBuilderFn regionBuilder,
+                               ArrayRef<AffineMap> indexingMaps) {
+  // Initialize indexingMaps attribute, for BatchMatmulOp.
+  SmallVector<Attribute, 4> indexingMapsAttrVal;
+  indexingMapsAttrVal =
+      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+        return AffineMapAttr::get(map);
+      });
+  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3450,6 +3467,46 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+  assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
+  AffineExpr exp = bcastMap.getResult(0);
+  return exp.isFunctionOfDim(0);
+}
+
+/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opIndexingMaps =
+      batchMatmulOp.getIndexingMapsArray();
+  SmallVector<AffineMap, 3> defaultIndexingMaps =
+      batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+
+  auto opIndexingMap = opIndexingMaps[opIndex];
+  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+  // Check general validity of indexing map results.
+  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Unexpected dim expression in map result.";
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) {
+      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+    }
+  } else {
+    if (!isValidBatchDim(opIndexingMap)) {
+      return batchMatmulOp->emitOpError()
+             << "Invalid batch dimension expression.";
+    }
+  }
+  return success();
+}
+
 namespace mlir {
 namespace linalg {
 
@@ -3611,5 +3668,165 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  AffineExpr d0, d1, d2, d3;
+  SmallVector<AffineMap> indexingMaps;
+  bindDims(context, d0, d1, d2, d3);
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
+  return indexingMaps;
+}
+
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{
+      utils::IteratorType::parallel, utils::IteratorType::parallel,
+      utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchMatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> defaultMaps =
+      getDefaultIndexingMaps(this->getContext());
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+  assert(bcastMap.getNumResults() < 3 && "Expected single result dim expr.");
+  bool isValid = false;
+  enum Indices { batchPos, mPos, nPos, kPos };
+  if (bcastMap.getNumResults() == 1) {
+    AffineExpr exp = bcastMap.getResult(0);
+    isValid = exp.isFunctionOfDim(kPos);
+  } else if (bcastMap.getNumResults() == 2) {
+    AffineExpr exp0 = bcastMap.getResult(0);
+    AffineExpr exp1 = bcastMap.getResult(1);
+    isValid = isLHS
+                  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+                  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+  }
+  return isValid;
+}
+
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  assert(3 > 0 && block.getNumArguments() == 3 &&
+         "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = llvm::map_to_vector(
+        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
+  return ::parseNamedStructuredOp(parser, result,
+                                  BatchMatmulOp::getNumRegionArgs(),
+                                  BatchMatmulOp::getRegionBuilder());
+}
+
+void BatchMatmulOp::print(OpAsmPrinter &p) {
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchMatmulOp::verify() {
+  // Verification of pure batch_matmul is handled by
+  // verifyStructuredOpInterface().
+  if (!hasUserDefinedMaps())
+    return success();
+
+  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+    if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+      return failure();
+  }
+  return success();
+}
+
+LogicalResult BatchMatmulOp::fold(FoldAdaptor,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void BatchMatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9b97865990bfdd..a5d4c7fe9908c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -935,7 +935,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
         loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
         ValueRange{collapsedInit});
     for (auto attr : contractionOp->getAttrs()) {
-      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
+          attr.getName() == "indexing_maps")
         continue;
       collapsedOp->setAttr(attr.getName(), attr.getValue());
     }
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca..040663c882a086 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -484,24 +484,6 @@ def batch_mmt4d(
     ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
 
 
-@linalg_structured_op
-def batch_matmul(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
-    """Performs a batched matrix multiplication of two 3D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n]
-    )
-
-
 @linalg_structured_op
 def batch_matmul_transpose_a(
     A=TensorDef(T1, Batch, S.K, S.M),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..638238b5c38a60 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1002,3 +1002,26 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 
 // -----
 
+// CHECK: #[[$ATTR_...
[truncated]
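
For orientation, the truncated portion includes the new invalid.mlir cases exercising the added verifier. As an illustrative sketch (not copied from the patch), a map that uses a dimension outside the corresponding default map is rejected; here the LHS map uses d2 (the n dimension), which the verifier reports as "Unexpected dim expression in map result.":
```
linalg.batch_matmul indexing_maps = [
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, // invalid: d2 is not in the default LHS map (d0, d1, d3)
                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                    ]
                    ins(%arg0, %arg1 : memref<2x3x7xf32>, memref<2x5x7xf32>)
                    outs(%arg2 : memref<2x3x7xf32>)
```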

@llvmbot (Member) commented Jan 9, 2025

@llvm/pr-subscribers-mlir


@banach-space (Contributor)

Please, could you provide references to relevant RFCs where this was discussed/agreed-on? And include in the summary. Thanks :)

@shahidact (Contributor, Author)

Please, could you provide references to relevant RFCs where this was discussed/agreed-on? And include in the summary. Thanks :)

Added :)

@MaheshRavishankar (Contributor)

Haven't reviewed in detail yet. How is this related to https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589 ?

@rengolin (Member)

Haven't reviewed in detail yet. How is this related to https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589 ?

It is not. This is just a cleanup of OpDSL that we agreed on last year. Please do not cross the wires here.

@banach-space (Contributor)

Please, could you provide references to relevant RFCs where this was discussed/agreed-on? And include in the summary. Thanks :)

Added :)

Thanks! @MaheshRavishankar, relevant links have been added in the summary at the top:

RFCs and related PR:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
#115319

Looks like the syntax introduced here matches #115319 (unless I missed sth), so this is consistent with the previous patches in the series 👍🏻
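
(For comparison, a sketch of the analogous form that #115319 introduced on linalg.matmul, here with a transposed A; shapes and SSA names are illustrative rather than taken from that patch:)
```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose A: (k, m)
                affine_map<(d0, d1, d2) -> (d2, d1)>, // B: (k, n)
                affine_map<(d0, d1, d2) -> (d0, d1)>  // C: (m, n)
              ]
              ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
              outs(%arg2 : memref<3x7xf32>)
```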

One high-level (kind) request:

Goals:

To add syntax and semantic to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul.

Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra.

Why not split this into two patches? Those two goals seem orthogonal to me and reviewing small patches is easier (at least on GitHub).

@rengolin (Member)

Why not split this into two patches? Those two goals seem orthogonal to me and reviewing small patches is easier (at least on GitHub).

From discussions on the previous PR (matmul), there should be no differences from what we did there. The ODS implementation is identical to the OpDSL one except for the affine maps. All previous tests should pass. So the extra logic that needs to be reviewed is actually a very small portion of the whole patch. If we split, the first patch would be almost as big as this one.

But getting the semantics correct is not trivial either. So splitting could lead to an interim state that is neither here nor there, and we want to avoid that.

@rengolin (Member)

Looks like the syntax introduced here matches #115319 (unless I missed sth), so this is consistent with the previous patches in the series 👍🏻

That's the idea. After this, there's one last patch for batch_reduce_matmul, then we can start the downstream cleanup to stop using the _transpose_ variants in favour of these ones, and at some point in the future decide to drop them from OpDSL altogether.
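
To make that migration concrete, a sketch of how a batch_matmul_transpose_a would be written with the new attribute instead of the dedicated variant (shapes are illustrative):
```
// Dedicated variant:
linalg.batch_matmul_transpose_a ins(%A, %B : memref<2x5x3xf32>, memref<2x5x7xf32>)
                                outs(%C : memref<2x3x7xf32>)

// With this patch, expressed via indexing_maps (A indexed as (batch, k, m)):
linalg.batch_matmul indexing_maps = [
                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                    ]
                    ins(%A, %B : memref<2x5x3xf32>, memref<2x5x7xf32>)
                    outs(%C : memref<2x3x7xf32>)
```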

@MaheshRavishankar (Contributor)

Thanks @shahidact, looks good to me (I didn't review all the details, but mostly the structure of things, and that matches what I expect).

@shahidact (Contributor, Author)

Thanks @shahidact, looks good to me (I didn't review all the details, but mostly the structure of things, and that matches what I expect).

Thanks a lot.

@banach-space (Contributor) left a comment


Thanks for working on this, the overall structure is good!

I've skimmed through the easier part and left some minor comments. I'll try to go over the rest in the next day or two.

@adam-smnk (Contributor) left a comment


Thanks for the fixes 👍

LGTM % open conversations

@banach-space (Contributor) left a comment


What I've reviewed so far LGTM % the remaining comments (please address before landing). I’m relying on @adam-smnk for the rest 😅

I will be OOO for a week, so approving as is to unblock you. Please feel free to merge once all outstanding comments are resolved.

Last but not least, thank you for working on this, @shahidact - this is a non-trivial area, and your work will have a significant impact on Linalg 🙏🏻

// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
Contributor


I see that you have updated @batch_matmul_bcast_k_to_fill_missing_dims_A, but forgot to use similar naming style for other tests. Did you mean @batch_matmul_bcast_m_to_fill_missing_batch_dim_A? Same for other tests.

Contributor Author


Did you mean @batch_matmul_bcast_m_to_fill_missing_batch_dim_A?

I meant to broadcast the row-column (M×K) matrix A across the broadcast dim.
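
For context, a sketch of the test under discussion, reconstructed from the CHECK lines above and the op's syntax (the exact body in the patch may differ): the whole M×K matrix A is broadcast across the batch dimension.
```
func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
  linalg.batch_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // A has no batch dim; broadcast it
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                      ]
                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>)
                      outs(%arg2 : memref<2x3x7xf32>)
  return
}
```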

…h_matmul' operation.

-Replaced the assert on the number of dim expressions with proper error reporting and a new test case.
-Fixed typos.
*Added and updated test cases.
*Refactored verification logic.
…lOp.

*Updated test names and comments for consistency.
…collapsed contraction op.

*Refactored some tests and methods for better naming, comments and readability.
…ontraction Op having user defined indexing_maps.
@rengolin merged commit f2bca9e into llvm:main Feb 6, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…llvm#122275)
