-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… #122275
Conversation
@llvm/pr-subscribers-mlir-linalg Author: Md Asghar Ahmad Shahid (shahidact) Changes…h_matmul' operation. Goals:
Scope of this patch: The broadcast and transpose semantic are as follows: By default, 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified.
Patch is 35.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122275.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b0ea1f76955816..496a323249e852 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul
- cpp_class_name: BatchMatmulOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul_transpose_a
cpp_class_name: BatchMatmulTransposeAOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..47b871aa322309 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,130 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
+ /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+
+ let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+ let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+ Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+ 'indexing_maps' as shown below.This is a list attribute, so the list must include all
+ the maps if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+
+ Example Broadcast and transpose:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+}];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchMatmulOp::getRegionBuilder(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ /// Returns a list of AffineMap with the typical batch_matmul indexing charactristic.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ bool hasDynamicIndexingMaps() { return true; }
+ std::string getLibraryCallName();
+ /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+ /// user defined indexing maps are not equal to default map.
+ bool hasUserDefinedMaps();
+ }];
+}
+
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..868892d1e5f5cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,23 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ ArrayRef<AffineMap> indexingMaps) {
+ // Initialize indexingMaps attribute, for BatchMatmulOp.
+ SmallVector<Attribute, 4> indexingMapsAttrVal;
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3450,6 +3467,46 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+ assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
+ AffineExpr exp = bcastMap.getResult(0);
+ return exp.isFunctionOfDim(0);
+}
+
+/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+ unsigned opIndex) {
+ SmallVector<AffineMap, 3> opIndexingMaps =
+ batchMatmulOp.getIndexingMapsArray();
+ SmallVector<AffineMap, 3> defaultIndexingMaps =
+ batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+
+ auto opIndexingMap = opIndexingMaps[opIndex];
+ auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+ // Check general validity of indexing map results.
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+ // Check if the requested broadcast is valid.
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+ if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) {
+ return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ }
+ } else {
+ if (!isValidBatchDim(opIndexingMap)) {
+ return batchMatmulOp->emitOpError()
+ << "Invalid batch dimension expression.";
+ }
+ }
+ return success();
+}
+
namespace mlir {
namespace linalg {
@@ -3611,5 +3668,165 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2, d3;
+ SmallVector<AffineMap> indexingMaps;
+ bindDims(context, d0, d1, d2, d3);
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
+ return indexingMaps;
+}
+
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchMatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchMatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 && "Expected single result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
+ AffineExpr exp = bcastMap.getResult(0);
+ isValid = exp.isFunctionOfDim(kPos);
+ } else if (bcastMap.getNumResults() == 2) {
+ AffineExpr exp0 = bcastMap.getResult(0);
+ AffineExpr exp1 = bcastMap.getResult(1);
+ isValid = isLHS
+ ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+ : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ }
+ return isValid;
+}
+
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(3 > 0 && block.getNumArguments() == 3 &&
+ "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ if (parser.parseLSquare())
+ return failure();
+
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr)) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ }
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ indexingMapsAttr = llvm::map_to_vector(
+ BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
+ return ::parseNamedStructuredOp(parser, result,
+ BatchMatmulOp::getNumRegionArgs(),
+ BatchMatmulOp::getRegionBuilder());
+}
+
+void BatchMatmulOp::print(OpAsmPrinter &p) {
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ BatchMatmulOp::getDefaultIndexingMaps(getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchMatmulOp::verify() {
+ // Verification of pure batch_matmul is handled by
+ // verifyStructuredOpInterface().
+ if (!hasUserDefinedMaps())
+ return success();
+
+ for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+ if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult BatchMatmulOp::fold(FoldAdaptor,
+ SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void BatchMatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9b97865990bfdd..a5d4c7fe9908c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -935,7 +935,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
for (auto attr : contractionOp->getAttrs()) {
- if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
+ attr.getName() == "indexing_maps")
continue;
collapsedOp->setAttr(attr.getName(), attr.getValue());
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca..040663c882a086 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -484,24 +484,6 @@ def batch_mmt4d(
) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
-@linalg_structured_op
-def batch_matmul(
- A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
- """Performs a batched matrix multiplication of two 3D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]
- )
-
-
@linalg_structured_op
def batch_matmul_transpose_a(
A=TensorDef(T1, Batch, S.K, S.M),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..638238b5c38a60 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1002,3 +1002,26 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// -----
+// CHECK: #[[$ATTR_...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Md Asghar Ahmad Shahid (shahidact) Changes…h_matmul' operation. Goals:
Scope of this patch: The broadcast and transpose semantic are as follows: By default, 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified.
Patch is 35.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122275.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b0ea1f76955816..496a323249e852 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul
- cpp_class_name: BatchMatmulOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul_transpose_a
cpp_class_name: BatchMatmulTransposeAOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..47b871aa322309 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,130 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
+ /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+
+ let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+ let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+ Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+ 'indexing_maps' as shown below.This is a list attribute, so the list must include all
+ the maps if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+
+ Example Broadcast and transpose:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+}];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchMatmulOp::getRegionBuilder(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ /// Returns a list of AffineMap with the typical batch_matmul indexing charactristic.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ bool hasDynamicIndexingMaps() { return true; }
+ std::string getLibraryCallName();
+ /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+ /// user defined indexing maps are not equal to default map.
+ bool hasUserDefinedMaps();
+ }];
+}
+
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..868892d1e5f5cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,23 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ ArrayRef<AffineMap> indexingMaps) {
+ // Initialize indexingMaps attribute, for BatchMatmulOp.
+ SmallVector<Attribute, 4> indexingMapsAttrVal;
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3450,6 +3467,46 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+ assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
+ AffineExpr exp = bcastMap.getResult(0);
+ return exp.isFunctionOfDim(0);
+}
+
+/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+ unsigned opIndex) {
+ SmallVector<AffineMap, 3> opIndexingMaps =
+ batchMatmulOp.getIndexingMapsArray();
+ SmallVector<AffineMap, 3> defaultIndexingMaps =
+ batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+
+ auto opIndexingMap = opIndexingMaps[opIndex];
+ auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+ // Check general validity of indexing map results.
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+ // Check if the requested broadcast is valid.
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+ if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) {
+ return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ }
+ } else {
+ if (!isValidBatchDim(opIndexingMap)) {
+ return batchMatmulOp->emitOpError()
+ << "Invalid batch dimension expression.";
+ }
+ }
+ return success();
+}
+
namespace mlir {
namespace linalg {
@@ -3611,5 +3668,165 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2, d3;
+ SmallVector<AffineMap> indexingMaps;
+ bindDims(context, d0, d1, d2, d3);
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
+ return indexingMaps;
+}
+
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchMatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchMatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 && "Expected single result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
+ AffineExpr exp = bcastMap.getResult(0);
+ isValid = exp.isFunctionOfDim(kPos);
+ } else if (bcastMap.getNumResults() == 2) {
+ AffineExpr exp0 = bcastMap.getResult(0);
+ AffineExpr exp1 = bcastMap.getResult(1);
+ isValid = isLHS
+ ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+ : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ }
+ return isValid;
+}
+
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(3 > 0 && block.getNumArguments() == 3 &&
+ "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ if (parser.parseLSquare())
+ return failure();
+
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr)) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ }
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ indexingMapsAttr = llvm::map_to_vector(
+ BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
+ return ::parseNamedStructuredOp(parser, result,
+ BatchMatmulOp::getNumRegionArgs(),
+ BatchMatmulOp::getRegionBuilder());
+}
+
+void BatchMatmulOp::print(OpAsmPrinter &p) {
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ BatchMatmulOp::getDefaultIndexingMaps(getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchMatmulOp::verify() {
+ // Verification of pure batch_matmul is handled by
+ // verifyStructuredOpInterface().
+ if (!hasUserDefinedMaps())
+ return success();
+
+ for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+ if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult BatchMatmulOp::fold(FoldAdaptor,
+ SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void BatchMatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9b97865990bfdd..a5d4c7fe9908c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -935,7 +935,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
for (auto attr : contractionOp->getAttrs()) {
- if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
+ attr.getName() == "indexing_maps")
continue;
collapsedOp->setAttr(attr.getName(), attr.getValue());
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca..040663c882a086 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -484,24 +484,6 @@ def batch_mmt4d(
) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
-@linalg_structured_op
-def batch_matmul(
- A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
- """Performs a batched matrix multiplication of two 3D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]
- )
-
-
@linalg_structured_op
def batch_matmul_transpose_a(
A=TensorDef(T1, Batch, S.K, S.M),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..638238b5c38a60 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1002,3 +1002,26 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// -----
+// CHECK: #[[$ATTR_...
[truncated]
|
Please, could you provide references to relevant RFCs where this was discussed/agreed-on? And include in the summary. Thanks :) |
Added :) |
Havent reviewed in detail yet. How is this related to https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589 . |
It is not. This is just a cleanup of OpDSL that we agreed last year. Please do not cross the wires here. |
Thanks! @MaheshRavishankar , relevant links have been added in the summary at the top:
Looks like the syntax introduced here matches #115319 (unless I missed sth), so this is consistent with the previous patches in the series 👍🏻 One high level (kind) request:
Why not split this into two patches? Those two goals seem orthogonal to me and reviewing small patches is easier (at least on GitHub). |
From discussions on the previous PR (matmul), there should be no differences to what we did there. The ODS implementation is identical to the OpDSL one except the affine map. All previous tests should pass. So the extra logic that needs to be reviewed is actually a very small portion of the whole patch. If we split, the first patch would be almost as big as this one. But getting the semantics correct is not trivial either. So splitting could lead to an interim state that is not quite here nor there, and we want to avoid that. |
That's the idea. After this, there's a last one for |
Thanks @shahidact looks good to me (i didnt review all the details, but mostly for structure of things and that matches what I expect) |
Thanks a lot. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this, the overall structure is good!
I've skimmed through the easier part and left some minor comments. I'll try to go over the rest in the next day/two.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fixes 👍
LGTM % open conversations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I've reviewed so far LGTM % the remaining comments (please address before landing). I’m relying on @adam-smnk for the rest 😅
I will be OOO for a week, so approving as is to unblock you. Please feel free to merge once all outstanding comments are resolved.
Last but not least, thank you for working on this, @shahidact - this is a non-trivial area, and your work will have a significant impact on Linalg 🙏🏻
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] | ||
// CHECK: return | ||
// CHECK: } | ||
func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see that you have updated @batch_matmul_bcast_k_to_fill_missing_dims_A
, but forgot to use similar naming style for other tests. Did you mean @batch_matmul_bcast_m_to_fill_missing_batch_dim_A
? Same for other tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean
@batch_matmul_bcast_m_to_fill_missing_batch_dim_A
?
I meant to broadcast row-column of matrix A across broadcast dim.
…h_matmul' operation. Goals: 1. To add syntax and semantic to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul. 2. Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra. Scope of this patch: To expose broadcast and transpose semantics on the 'batch_matmul'. The broadcast and transpose semantic is as follows: By default 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below.This is a list attribute, so the list must include all the maps if specified. Example Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) ``` Example Broadcast: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, //broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>,memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) ``` Example Broadcast and transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, //broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>) ```
-Replaced assert for the count of number of dim expression with proper error reporting and new test case. -Fixed typos.
*Added and udated test cases. *Refactored verification logic.
…lOp. *Updated test names and comments for consistency.
…collapsed contraction op. *Refactored some tests and methods for better naming, comments and readability.
…ontraction Op having user defined indexing_maps.
14b440b
to
048a481
Compare
…llvm#122275) Goals: 1. To add syntax and semantic to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul. 2. Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra. Scope of this patch: To expose broadcast and transpose semantics on the 'batch_matmul'. The broadcast and transpose semantic are as follows: By default, 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified. Example Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose affine_map< (d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<2x5x3xf32>,memref<2x5x7xf32>) outs (%arg2: memref<2x3x7xf32>) ``` Example Broadcast: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d3)>, //broadcast affine_map< (d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<5xf32>,memref<2x5x7xf32>) outs (%arg2: memref<2x3x7xf32>) ``` Example Broadcast and transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d1, d3)>, //broadcast affine_map< (d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<3x5xf32>, memref<2x7x5xf32>) outs (%arg2: memref<2x3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 llvm#115319
…h_matmul' operation.
Goals:
To add syntax and semantic to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul.
Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra.
Scope of this patch:
To expose broadcast and transpose semantics on the 'batch_matmul'.
The broadcast and transpose semantic are as follows:
By default, 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified.
RFCs and related PR:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
#115319