[mlir][tensor] Make useful Tensor utilities public #126802
Conversation
1. Extract the main logic from `foldTensorCastPrecondition` into a dedicated helper hook, `hasFoldableTensorCastOperand`, so that the corresponding checks can be reused.
2. Rename `getNewOperands` to `getUpdatedOperandsAfterCastOpFolding`, for better clarity and to document its functionality.
3. These updated hooks will be reused in:
   * llvm#123902.

This PR makes them public.

**Note:** Moving these hooks to `Tensor/Utils` is not feasible because `MLIRTensorUtils` depends on `MLIRTensorDialect` (CMake targets). Moving them to `Utils` would therefore make `MLIRTensorDialect` depend on `MLIRTensorUtils`, creating a circular dependency.
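For illustration, here is a minimal sketch (not part of this PR) of how a downstream pattern, such as the Linalg work in llvm#123902, might reuse the two hooks once they are public. `MyDpsOp` and the pattern name are hypothetical placeholders; `MyDpsOp` is assumed to implement `DestinationStyleOpInterface`, and the body mirrors the `FoldTensorCastProducerOp` pattern in the diff below.

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical pattern over a hypothetical destination-style op `MyDpsOp`.
struct FoldTensorCastIntoMyDpsOp : public OpRewritePattern<MyDpsOp> {
  using OpRewritePattern<MyDpsOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MyDpsOp op,
                                PatternRewriter &rewriter) const override {
    // Cheap precondition: bail out unless at least one operand is a
    // tensor.cast accepted by canFoldIntoConsumerOp.
    if (!tensor::hasFoldableTensorCastOperand(op))
      return failure();

    // Compute the post-folding operands; the result types of DPS inits are
    // updated in place in `newResultTypes`.
    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Re-create the op with the refined operands and result types, using the
    // same `clone` helper as FoldTensorCastProducerOp in this PR's diff.
    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
```

Splitting the precondition from the operand rewriting is what lets callers run the cheap check first and share the heavier rewriting logic across patterns.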
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: Andrzej Warzyński (banach-space)

Changes: see the PR description above.

Full diff: https://github.com/llvm/llvm-project/pull/126802.diff

2 Files Affected:
* mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
* mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 1bd0f6553fc8d..b3ec796a72337 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -116,6 +116,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
/// this method provides a check that it is worth doing the canonicalization.
bool canFoldIntoProducerOp(CastOp castOp);
+/// Return true if any of the operands of `op` is a CastOp that can be folded
+/// into its consumer, i.e. `op`. This is effectively a convenience wrapper for
+/// `canFoldIntoProducerOp`.
+bool hasFoldableTensorCastOperand(Operation *op);
+
+/// Assuming that `op` contains at least one operand that is a foldable CastOp
+/// (i.e. `hasFoldableTensorCastOperand` returns true), calculate the updated
+/// operands.
+SmallVector<Value>
+getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
+ SmallVector<Type> &newResTy);
+
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult foldTensorCast(Operation *op);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fda6246334e15..03c2f3843f262 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
castOp.getType());
}
+bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
+ return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+ if (llvm::isa<BlockArgument>(opOperand.get()))
+ return false;
+ auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+ return castOp && canFoldIntoConsumerOp(castOp);
+ });
+}
+
+SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
+ DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
+ SmallVector<Value> newOperands;
+ newOperands.reserve(op->getNumOperands());
+
+ assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
+
+ // Assumes that the result has dpsInits followed by nonDpsInits.
+ int64_t dpsInitIdx = 0;
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+ bool fold = canFoldIntoConsumerOp(tensorCastOp);
+ newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+ if (op.isDpsInit(&opOperand) &&
+ !llvm::isa<MemRefType>(newOperands.back().getType()))
+ newResTy[dpsInitIdx++] = newOperands.back().getType();
+ }
+ return newOperands;
+}
+
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
@@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
isa<LoopLikeOpInterface>(op.getOperation()))
return false;
- // If no operand comes from a tensor::CastOp and can be folded then fail.
- bool hasTensorCastOperand =
- llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
- if (llvm::isa<BlockArgument>(opOperand.get()))
- return false;
- auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
- return castOp && canFoldIntoConsumerOp(castOp);
- });
-
- return hasTensorCastOperand;
-}
-
-static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
- SmallVector<Type> &newResTy) {
- SmallVector<Value> newOperands;
- newOperands.reserve(op->getNumOperands());
-
- // Assumes that the result has dpsInits followed by nonDpsInits.
- int64_t dpsInitIdx = 0;
- for (OpOperand &opOperand : op->getOpOperands()) {
- auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
- bool fold = canFoldIntoConsumerOp(tensorCastOp);
- newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
- if (op.isDpsInit(&opOperand) &&
- !llvm::isa<MemRefType>(newOperands.back().getType()))
- newResTy[dpsInitIdx++] = newOperands.back().getType();
- }
- return newOperands;
+ return hasFoldableTensorCastOperand(op);
}
// Given the (potentially) updated packed type, `newPackedTy`, generates an
@@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
- SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ SmallVector<Value> newOperands =
+ getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
// Get the updated mixed-tile-sizes attribute.
SmallVector<OpFoldResult> newMixedTileSizes =
@@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
- SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ SmallVector<Value> newOperands =
+ getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
Value sourceTensor = newOperands[0];
// Get the updated mixed-tile-sizes attribute.
@@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
- SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ SmallVector<Value> newOperands =
+ getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
// Clone op
auto newOp = clone(rewriter, op, newResultTypes, newOperands);
Thanks! This makes sense to me. It's similar to existing patterns in LLVM wrt folding casts to match patterns, so LGTM.
The very minor redundancy between the two new functions can't be factored out in any reasonable way that would be both smaller and simpler, so that's not a problem either.
Thank you!