
[mlir][tensor] Make useful Tensor utilities public #126802

Merged (1 commit into llvm:main) Feb 12, 2025

Conversation

banach-space
Contributor


1. Extract the main logic from `foldTensorCastPrecondition` into a dedicated
   helper hook: `hasFoldableTensorCastOperand`. This allows the corresponding
   checks to be reused elsewhere.

2. Rename `getNewOperands` to `getUpdatedOperandsAfterCastOpFolding` to better
   document its functionality.

3. Make these updated hooks public so that they can be reused in llvm#123902.

**Note:** Moving these hooks to `Tensor/Utils` is not feasible because
`MLIRTensorUtils` depends on `MLIRTensorDialect` (CMake targets). If these
hooks were moved to `Utils`, it would create a dependency of `MLIRTensorDialect`
on `MLIRTensorUtils`, leading to a circular dependency.
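
For readers unfamiliar with the pattern, the logic behind the two hooks can be sketched in plain Python. This is a hypothetical model, not the MLIR C++ API: `Value`, its `cast_source` field, the `"?"`-based staticness check, and the `is_dps_init` list are all stand-ins for their C++ counterparts (`OpOperand`, `tensor::CastOp`, `canFoldIntoConsumerOp`, and `DestinationStyleOpInterface::isDpsInit`).

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Value:
    """Stand-in for an MLIR Value: a type string plus, if the value is the
    result of a tensor.cast, the cast's source value."""
    type: str                               # e.g. "tensor<?xf32>"
    cast_source: Optional["Value"] = None   # source of a tensor.cast, if any

def can_fold_into_consumer(v: Value) -> bool:
    # Simplified stand-in for canFoldIntoConsumerOp: fold when the value comes
    # from a cast whose source type is fully static (no '?' dimensions).
    return v.cast_source is not None and "?" not in v.cast_source.type

def has_foldable_tensor_cast_operand(operands: list) -> bool:
    # Mirrors hasFoldableTensorCastOperand: any operand is a foldable cast.
    return any(can_fold_into_consumer(v) for v in operands)

def get_updated_operands_after_cast_op_folding(
    operands: list, is_dps_init: list, new_res_ty: list
) -> list:
    # Mirrors getUpdatedOperandsAfterCastOpFolding: replace each foldable cast
    # result with its source; DPS init operands also refresh the result types.
    assert has_foldable_tensor_cast_operand(operands), "No foldable CastOp operands!"
    new_operands, dps_init_idx = [], 0
    for v, is_init in zip(operands, is_dps_init):
        new_operands.append(v.cast_source if can_fold_into_consumer(v) else v)
        if is_init:
            new_res_ty[dps_init_idx] = new_operands[-1].type
            dps_init_idx += 1
    return new_operands
```

Folding a cast of `tensor<4xf32>` to `tensor<?xf32>` used as a DPS init would then replace the operand with the cast source and tighten the corresponding result type to `tensor<4xf32>`.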
@llvmbot
Member

llvmbot commented Feb 11, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Andrzej Warzyński (banach-space)


Full diff: https://github.com/llvm/llvm-project/pull/126802.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (+12)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+36-31)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 1bd0f6553fc8d..b3ec796a72337 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -116,6 +116,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
 /// this method provides a check that it is worth doing the canonicalization.
 bool canFoldIntoProducerOp(CastOp castOp);
 
+/// Return true if any of the operands of `op` is a CastOp that can be folded
+/// into its consumer, i.e. `op`. This is effectively a convenience wrapper for
+/// `canFoldIntoConsumerOp`.
+bool hasFoldableTensorCastOperand(Operation *op);
+
+/// Assuming that `op` contains at least one operand that is a foldable CastOp
+/// (i.e. `hasFoldableTensorCastOperand` returns true), calculate the updated
+/// operands.
+SmallVector<Value>
+getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
+                                     SmallVector<Type> &newResTy);
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fda6246334e15..03c2f3843f262 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
                                     castOp.getType());
 }
 
+bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
+  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+    if (llvm::isa<BlockArgument>(opOperand.get()))
+      return false;
+    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    return castOp && canFoldIntoConsumerOp(castOp);
+  });
+}
+
+SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
+    DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
+
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+}
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
@@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
       isa<LoopLikeOpInterface>(op.getOperation()))
     return false;
 
-  // If no operand comes from a tensor::CastOp and can be folded then fail.
-  bool hasTensorCastOperand =
-      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-        if (llvm::isa<BlockArgument>(opOperand.get()))
-          return false;
-        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-        return castOp && canFoldIntoConsumerOp(castOp);
-      });
-
-  return hasTensorCastOperand;
-}
-
-static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
-                                         SmallVector<Type> &newResTy) {
-  SmallVector<Value> newOperands;
-  newOperands.reserve(op->getNumOperands());
-
-  // Assumes that the result has dpsInits followed by nonDpsInits.
-  int64_t dpsInitIdx = 0;
-  for (OpOperand &opOperand : op->getOpOperands()) {
-    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-    bool fold = canFoldIntoConsumerOp(tensorCastOp);
-    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-    if (op.isDpsInit(&opOperand) &&
-        !llvm::isa<MemRefType>(newOperands.back().getType()))
-      newResTy[dpsInitIdx++] = newOperands.back().getType();
-  }
-  return newOperands;
+  return hasFoldableTensorCastOperand(op);
 }
 
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
@@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
     SmallVector<OpFoldResult> newMixedTileSizes =
@@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
@@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Clone op
     auto newOp = clone(rewriter, op, newResultTypes, newOperands);
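
As an illustration of the canonicalization these hooks support, consider the following hand-written sketch (`"some.op"` is a placeholder for an arbitrary destination-style op; it is not taken from this PR or its tests):

```mlir
// Before: the consumer sees the less static type produced by the cast.
%cast = tensor.cast %src : tensor<4x8xf32> to tensor<?x?xf32>
%res = "some.op"(%cast) : (tensor<?x?xf32>) -> tensor<?x?xf32>

// After folding the cast into its consumer: the static type propagates
// through the op, and a cast back is inserted only where other users
// still require the original, less static result type.
%res2 = "some.op"(%src) : (tensor<4x8xf32>) -> tensor<4x8xf32>
%cast2 = tensor.cast %res2 : tensor<4x8xf32> to tensor<?x?xf32>
```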

@rengolin (Member) left a comment

Thanks! This makes sense to me. It's similar to existing patterns in LLVM wrt folding casts to match patterns, so LGTM.

The very minor redundancy between the two new functions cannot be commoned up in any reasonable way that will both be smaller and look simpler, so not a problem either.

Thank you!

@banach-space banach-space merged commit 5586541 into llvm:main Feb 12, 2025
11 checks passed
@banach-space banach-space deleted the andrzej/extract_tensor_helper branch February 13, 2025 08:38
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025