-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Linalg] NFC: Expose a method to deduplicate operands/remove dead results of linalg.generic
op.
#125141
base: main
Are you sure you want to change the base?
[mlir][Linalg] NFC: Expose a method to deduplicate operands/remove dead results of linalg.generic
op.
#125141
Conversation
…ad results of `linalg.generic` op. This functionality was wrapped within a pattern. Expose this as a separate transformations function that can be used outside of pattern rewrite mechanism. Signed-off-by: MaheshRavishankar <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: None (MaheshRavishankar) ChangesThis functionality was wrapped within a pattern. Expose this as a separate transformations function that can be used outside of pattern rewrite mechanism. Patch is 24.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125141.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index eed279b6be34ac..8e978413cdea0f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1785,8 +1785,16 @@ void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
+/// Method to deduplicate operands and remove dead results of `linalg.generic`
+/// operations. This is effectively DCE for a linalg.generic op. If there is
+/// deduplication of operands orremoval of results, replaces the `genericOp`
+/// with a new op and returns it. Returns the same operation if there is no
+/// deduplication/removal.
+FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
+ RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs);
+
/// Pattern to remove dead operands and results of `linalg.generic` operations.
-/// This is effectively DCE for a linalg op.
+/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
/// Patterns to promote inputs to outputs and remove unused inputs of
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index 16ab45ea8bee63..8b05f21bb08d16 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -52,255 +52,266 @@ static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
return true;
}
-namespace {
-
-struct DeduplicateAndRemoveDeadOperandsAndResults
- : public OpRewritePattern<GenericOp> {
- DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
- bool removeOutputs)
- : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- // Create a map from argument position in the original op to the argument
- // position in the new op. If the argument is dropped it wont have an entry.
- SmallVector<OpOperand *> droppedOpOperands;
-
- // Information needed to build the new op.
- SmallVector<Value> newInputOperands, newOutputOperands;
- SmallVector<AffineMap> newIndexingMaps;
-
- // Gather information about duplicate input operands.
- llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
- deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
- newIndexingMaps);
-
- // Gather information about the dropped outputs.
- llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
- deduplicateOutputOperands(genericOp, droppedOpOperands,
- newOutputOperands, newIndexingMaps);
-
- // Check if there is any change to operands.
- if (newInputOperands.size() + newOutputOperands.size() ==
- genericOp->getNumOperands())
- return failure();
-
- // Create the new op with the body being empty.
- Location loc = genericOp.getLoc();
- SmallVector<Type> newResultTypes;
- for (Value v : newOutputOperands)
- if (isa<TensorType>(v.getType()))
- newResultTypes.push_back(v.getType());
- auto newOp = rewriter.create<GenericOp>(
- loc, newResultTypes, newInputOperands, newOutputOperands,
- rewriter.getAffineMapArrayAttr(newIndexingMaps),
- genericOp.getIteratorTypes(), genericOp.getDocAttr(),
- genericOp.getLibraryCallAttr(),
- [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
- return;
- });
- // Copy over unknown attributes. They might be load bearing for some flow.
- ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
- for (NamedAttribute kv : genericOp->getAttrs())
- if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
- newOp->setAttr(kv.getName(), kv.getValue());
-
- // Fix up the payload of the canonicalized operation.
- populateOpPayload(genericOp, newOp, origInsToNewInsPos,
- origOutsToNewOutsPos, rewriter);
-
- // Replace all live uses of the op.
- SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
- for (const auto &result : llvm::enumerate(genericOp.getResults())) {
- auto it = origOutsToNewOutsPos.find(result.index());
- if (it == origOutsToNewOutsPos.end())
+//===---------------------------------------------------------------------===//
+// Helper methods for operand deduplication and dead results elimination
+//===---------------------------------------------------------------------===//
+
+// Deduplicate input operands, and return the
+// - Mapping from operand position in the original op, to operand position in
+// the canonicalized op.
+// - The preserved input operands list (by reference).
+llvm::SmallDenseMap<unsigned, unsigned> static deduplicateInputOperands(
+ GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
+ SmallVector<Value> &newInputOperands,
+ SmallVector<AffineMap> &newIndexingMaps) {
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+ llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ OpOperand *inputOpOperand = en.value();
+ // Check if operand is dead and if dropping the indexing map makes the
+ // loops to shape computation invalid.
+ if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+ // Add the current operands to the list of potentially droppable
+ // operands. If it cannot be dropped, this needs to be popped back.
+ droppedOpOperands.push_back(inputOpOperand);
+ if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
continue;
- replacementsVals[result.index()] = newOp.getResult(it->second);
+ droppedOpOperands.pop_back();
}
- rewriter.replaceOp(genericOp, replacementsVals);
- return success();
- }
-private:
- /// If unset, outputs are not modified by this pattern.
- bool removeOutputs;
-
- // Deduplicate input operands, and return the
- // - Mapping from operand position in the original op, to operand position in
- // the canonicalized op.
- // - The preserved input operands list (by reference).
- llvm::SmallDenseMap<unsigned, unsigned>
- deduplicateInputOperands(GenericOp genericOp,
- SmallVector<OpOperand *> &droppedOpOperands,
- SmallVector<Value> &newInputOperands,
- SmallVector<AffineMap> &newIndexingMaps) const {
- llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
- llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
- for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
- OpOperand *inputOpOperand = en.value();
- // Check if operand is dead and if dropping the indexing map makes the
- // loops to shape computation invalid.
- if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
- // Add the current operands to the list of potentially droppable
- // operands. If it cannot be dropped, this needs to be popped back.
- droppedOpOperands.push_back(inputOpOperand);
- if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
- continue;
- droppedOpOperands.pop_back();
- }
+ // Check if this operand is a duplicate.
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
+ auto it =
+ dedupedInputs.find(std::make_pair(inputOpOperand->get(), indexingMap));
+ if (it != dedupedInputs.end()) {
+ origToNewPos[en.index()] = it->second;
+ droppedOpOperands.push_back(inputOpOperand);
+ continue;
+ }
- // Check if this operand is a duplicate.
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
- auto it = dedupedInputs.find(
- std::make_pair(inputOpOperand->get(), indexingMap));
- if (it != dedupedInputs.end()) {
- origToNewPos[en.index()] = it->second;
- droppedOpOperands.push_back(inputOpOperand);
- continue;
- }
+ // This is a preserved argument.
+ origToNewPos[en.index()] = newInputOperands.size();
+ dedupedInputs[{inputOpOperand->get(), indexingMap}] =
+ newInputOperands.size();
+ newInputOperands.push_back(inputOpOperand->get());
+ newIndexingMaps.push_back(indexingMap);
+ }
+ return origToNewPos;
+}
- // This is a preserved argument.
- origToNewPos[en.index()] = newInputOperands.size();
- dedupedInputs[{inputOpOperand->get(), indexingMap}] =
- newInputOperands.size();
- newInputOperands.push_back(inputOpOperand->get());
- newIndexingMaps.push_back(indexingMap);
+// Deduplicate output operands, and return the
+// - Mapping from operand position in the original op, to operand position in
+// the canonicalized op.
+// - The preserved output operands list (by reference).
+llvm::SmallDenseMap<unsigned, unsigned> static deduplicateOutputOperands(
+ GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
+ SmallVector<Value> &newOutputOperands,
+ SmallVector<AffineMap> &newIndexingMaps, bool removeOutputs) {
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+ llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
+ dedupedOutpts;
+ // If the op doesn't have tensor semantics or outputs should not be removed,
+ // keep all the outputs as preserved.
+ if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
+ origToNewPos[en.index()] = newOutputOperands.size();
+ newOutputOperands.push_back(en.value().get());
+ newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));
}
return origToNewPos;
}
-
- // Deduplicate output operands, and return the
- // - Mapping from operand position in the original op, to operand position in
- // the canonicalized op.
- // - The preserved output operands list (by reference).
- llvm::SmallDenseMap<unsigned, unsigned>
- deduplicateOutputOperands(GenericOp genericOp,
- SmallVector<OpOperand *> &droppedOpOperands,
- SmallVector<Value> &newOutputOperands,
- SmallVector<AffineMap> &newIndexingMaps) const {
- llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
- llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
- dedupedOutpts;
- // If the op doesn't have tensor semantics or outputs should not be removed,
- // keep all the outputs as preserved.
- if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
- for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
- origToNewPos[en.index()] = newOutputOperands.size();
- newOutputOperands.push_back(en.value().get());
- newIndexingMaps.push_back(
- genericOp.getMatchingIndexingMap(&en.value()));
+ // Output argument can be dropped if the result has
+ // - no users, and
+ // - it is not used in the payload, and
+ // - the corresponding indexing maps are not needed for loop bound
+ // computation.
+ auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+ for (const auto &outputOpOperand :
+ llvm::enumerate(genericOp.getDpsInitsMutable())) {
+ OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
+ AffineMap indexingMap =
+ genericOp.getMatchingIndexingMap(&outputOpOperand.value());
+ auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
+ yieldOp->getOperand(outputOpOperand.index()));
+ if (isResultValueDead(genericOp, result)) {
+ // Check if the opoperand can be dropped without affecting loop
+ // bound computation. Add the operand to the list of dropped op
+ // operand for checking. If it cannot be dropped, need to pop the
+ // value back.
+ droppedOpOperands.push_back(&outputOpOperand.value());
+ if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+ continue;
}
- return origToNewPos;
+ droppedOpOperands.pop_back();
}
- // Output argument can be dropped if the result has
- // - no users, and
- // - it is not used in the payload, and
- // - the corresponding indexing maps are not needed for loop bound
- // computation.
- auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
- for (const auto &outputOpOperand :
- llvm::enumerate(genericOp.getDpsInitsMutable())) {
- OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
- AffineMap indexingMap =
- genericOp.getMatchingIndexingMap(&outputOpOperand.value());
- auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
- yieldOp->getOperand(outputOpOperand.index()));
- if (isResultValueDead(genericOp, result)) {
- // Check if the opoperand can be dropped without affecting loop
- // bound computation. Add the operand to the list of dropped op
- // operand for checking. If it cannot be dropped, need to pop the
- // value back.
+
+ if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
+ // The out operand can also be dropped if it is computed redundantly
+ // by another result, the conditions for that are
+ // - The same operand is used as the out operand
+ // - The same indexing map is used
+ // - The same yield value is used.
+ auto it = dedupedOutpts.find(key);
+ if (it != dedupedOutpts.end()) {
+ origToNewPos[outputOpOperand.index()] = it->second;
droppedOpOperands.push_back(&outputOpOperand.value());
- if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
- continue;
- }
- droppedOpOperands.pop_back();
+ continue;
}
+ }
- if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
- // The out operand can also be dropped if it is computed redundantly
- // by another result, the conditions for that are
- // - The same operand is used as the out operand
- // - The same indexing map is used
- // - The same yield value is used.
- auto it = dedupedOutpts.find(key);
- if (it != dedupedOutpts.end()) {
- origToNewPos[outputOpOperand.index()] = it->second;
- droppedOpOperands.push_back(&outputOpOperand.value());
- continue;
- }
- }
+ origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+ dedupedOutpts[key] = newOutputOperands.size();
+ newOutputOperands.push_back(outputOpOperand.value().get());
+ newIndexingMaps.push_back(
+ genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
+ }
+ return origToNewPos;
+}
- origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
- dedupedOutpts[key] = newOutputOperands.size();
- newOutputOperands.push_back(outputOpOperand.value().get());
- newIndexingMaps.push_back(
- genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
+// Populate the body of the canonicalized operation.
+static void populateOpPayload(
+ GenericOp genericOp, GenericOp newOp,
+ const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
+ const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
+ RewriterBase &rewriter) {
+ // Merge the body of the original op with the new op.
+ Block *newOpBlock = &newOp.getRegion().front();
+ assert(newOpBlock->empty() && "expected new op to have an empty payload");
+ Block *origOpBlock = &genericOp.getRegion().front();
+ SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
+
+ // Replace all arguments in the original op, with arguments from the
+ // canonicalized op.
+ auto updateReplacements =
+ [&](SmallVector<OpOperand *> &origOperands,
+ SmallVector<OpOperand *> &newOperands,
+ const llvm::SmallDenseMap<unsigned, unsigned> &map) {
+ for (const auto &origOperand : llvm::enumerate(origOperands)) {
+ auto it = map.find(origOperand.index());
+ if (it == map.end())
+ continue;
+ OpOperand *newOperand = newOperands[it->second];
+ replacements[origOperand.value()->getOperandNumber()] =
+ newOpBlock->getArgument(newOperand->getOperandNumber());
+ }
+ };
+
+ SmallVector<OpOperand *> origInputOperands = genericOp.getDpsInputOperands();
+ SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
+ updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
+
+ SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range(
+ genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+ SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range(
+ newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+ updateReplacements(origOutputOperands, newOutputOperands,
+ origOutsToNewOutsPos);
+
+ // Drop the unused yield args.
+ if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
+ OpBuilder::InsertionGuard g(rewriter);
+ YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
+ rewriter.setInsertionPoint(origYieldOp);
+
+ SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
+ for (const auto &yieldOpOperands :
+ llvm::enumerate(origYieldOp.getValues())) {
+ auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
+ if (it == origOutsToNewOutsPos.end())
+ continue;
+ newYieldVals[it->second] = yieldOpOperands.value();
}
- return origToNewPos;
+ rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
}
- // Populate the body of the canonicalized operation.
- void populateOpPayload(
- GenericOp genericOp, GenericOp newOp,
- const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
- const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
- PatternRewriter &rewriter) const {
- // Merge the body of the original op with the new op.
- Block *newOpBlock = &newOp.getRegion().front();
- assert(newOpBlock->empty() && "expected new op to have an empty payload");
- Block *origOpBlock = &genericOp.getRegion().front();
- SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
-
- // Replace all arguments in the original op, with arguments from the
- // canonicalized op.
- auto updateReplacements =
- [&](SmallVector<OpOperand *> &origOperands,
- SmallVector<OpOperand *> &newOperands,
- const llvm::SmallDenseMap<unsigned, unsigned> &map) {
- for (const auto &origOperand : llvm::enumerate(origOperands)) {
- auto it = map.find(origOperand.index());
- if (it == map.end())
- continue;
- OpOperand *newOperand = newOperands[it->second];
- replacements[origOperand.value()->getOperandNumber()] =
- newOpBlock->getArgument(newOperand->getOperandNumber());
- }
- };
-
- SmallVector<OpOperand *> origInputOperands =
- genericOp.getDpsInputOperands();
- SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
- updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
-
- SmallVector<OpOperand *> origOutputOperands =
- llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
- [](OpOperand &o) { return &o; }));
- SmallVector<OpOperand *> newOutputOperands =
- llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
- [](OpOperand &o) { return &o; }));
- updateReplacements(origOutputOperands, newOutputOperands,
...
[truncated]
|
You can test this locally with the following command:git-clang-format --diff aa34a6ab299027ac31929173287e42db0dbdb06b 72f822d83f32d06175cd9817fd60a3250f1c0b6d --extensions cpp,h -- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp View the diff from clang-format here.diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index 8b05f21bb0..d375878fb2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -299,9 +299,8 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- FailureOr<GenericOp> newOp =
- deduplicateOperandsAndRemoveDeadResults(rewriter, genericOp,
- removeOutputs);
+ FailureOr<GenericOp> newOp = deduplicateOperandsAndRemoveDeadResults(
+ rewriter, genericOp, removeOutputs);
if (failed(newOp) || newOp.value() == genericOp) {
return rewriter.notifyMatchFailure(
genericOp, "failed to dedup operands/remove dead results");
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, I don't see why this should be only exposed as a pattern.
/// Method to deduplicate operands and remove dead results of `linalg.generic` | ||
/// operations. This is effectively DCE for a linalg.generic op. If there is | ||
/// deduplication of operands orremoval of results, replaces the `genericOp` | ||
/// with a new op and returns it. Returns the same operation if there is no | ||
/// deduplication/removal. | ||
FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults( | ||
RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move this transformation closer to other transformations (above all the populate patterns)? It's more consistent with how this file is structured.
rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals); | ||
} | ||
rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move these new static functions to an anonymous namespace?
This functionality was wrapped within a pattern. Expose this as a separate transformations function that can be used outside of pattern rewrite mechanism.