From ff29772a296d575c2a17e378ccd94c77f39c084e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Wed, 12 Feb 2025 23:12:14 +0000 Subject: [PATCH] [mlir][tensor] Make useful Tensor utilities public (#126802) 1. Extract the main logic from `foldTensorCastPrecondition` into a dedicated helper hook: `hasFoldableTensorCastOperand`. This allows for reusing the corresponding checks. 2. Rename `getNewOperands` to `getUpdatedOperandsAfterCastOpFolding` for better clarity and documentation of its functionality. 3. These updated hooks will be reused in: * https://github.com/llvm/llvm-project/pull/123902. This PR makes them public. **Note:** Moving these hooks to `Tensor/Utils` is not feasible because `MLIRTensorUtils` depends on `MLIRTensorDialect` (CMake targets). If these hooks were moved to `Utils`, it would create a dependency of `MLIRTensorDialect` on `MLIRTensorUtils`, leading to a circular dependency. --- mlir/include/mlir/Dialect/Tensor/IR/Tensor.h | 12 ++++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 67 +++++++++++--------- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h index 1bd0f6553fc8d..b3ec796a72337 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -116,6 +116,18 @@ bool canFoldIntoConsumerOp(CastOp castOp); /// this method provides a check that it is worth doing the canonicalization. bool canFoldIntoProducerOp(CastOp castOp); +/// Return true if any of the operands of `op` is a CastOp that can be folded +/// into its consumer, i.e. `op`. This is effectively a convenience wrapper for +/// `canFoldIntoProducerOp`. +bool hasFoldableTensorCastOperand(Operation *op); + +/// Assuming that `op` contains at least one operand that is a foldable CastOp +/// (i.e. `hasFoldableTensorCastOperand` returns true), calculate the updated +/// operands. +SmallVector +getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, + SmallVector &newResTy); + /// Performs folding of any operand of `op` if it comes from a tensor::CastOp /// that can be folded. LogicalResult foldTensorCast(Operation *op); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index fda6246334e15..03c2f3843f262 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) { castOp.getType()); } +bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) { + return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { + if (llvm::isa(opOperand.get())) + return false; + auto castOp = opOperand.get().getDefiningOp(); + return castOp && canFoldIntoConsumerOp(castOp); + }); +} + +SmallVector mlir::tensor::getUpdatedOperandsAfterCastOpFolding( + DestinationStyleOpInterface op, SmallVector &newResTy) { + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + + assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!"); + + // Assumes that the result has dpsInits followed by nonDpsInits. + int64_t dpsInitIdx = 0; + for (OpOperand &opOperand : op->getOpOperands()) { + auto tensorCastOp = opOperand.get().getDefiningOp(); + bool fold = canFoldIntoConsumerOp(tensorCastOp); + newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get()); + if (op.isDpsInit(&opOperand) && + !llvm::isa(newOperands.back().getType())) + newResTy[dpsInitIdx++] = newOperands.back().getType(); + } + return newOperands; +} + /// Performs folding of any operand of `op` if it comes from a tensor::CastOp /// that can be folded. LogicalResult mlir::tensor::foldTensorCast(Operation *op) { @@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) { isa(op.getOperation())) return false; - // If no operand comes from a tensor::CastOp and can be folded then fail. - bool hasTensorCastOperand = - llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { - if (llvm::isa(opOperand.get())) - return false; - auto castOp = opOperand.get().getDefiningOp(); - return castOp && canFoldIntoConsumerOp(castOp); - }); - - return hasTensorCastOperand; -} - -static SmallVector getNewOperands(DestinationStyleOpInterface op, - SmallVector &newResTy) { - SmallVector newOperands; - newOperands.reserve(op->getNumOperands()); - - // Assumes that the result has dpsInits followed by nonDpsInits. - int64_t dpsInitIdx = 0; - for (OpOperand &opOperand : op->getOpOperands()) { - auto tensorCastOp = opOperand.get().getDefiningOp(); - bool fold = canFoldIntoConsumerOp(tensorCastOp); - newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get()); - if (op.isDpsInit(&opOperand) && - !llvm::isa(newOperands.back().getType())) - newResTy[dpsInitIdx++] = newOperands.back().getType(); - } - return newOperands; + return hasFoldableTensorCastOperand(op); } // Given the (potentially) updated packed type, `newPackedTy`, generates an @@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern { return failure(); SmallVector newResultTypes(op->getResultTypes()); - SmallVector newOperands = getNewOperands(op, newResultTypes); + SmallVector newOperands = + getUpdatedOperandsAfterCastOpFolding(op, newResultTypes); // Get the updated mixed-tile-sizes attribute. SmallVector newMixedTileSizes = @@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern { return failure(); SmallVector newResultTypes(op->getResultTypes()); - SmallVector newOperands = getNewOperands(op, newResultTypes); + SmallVector newOperands = + getUpdatedOperandsAfterCastOpFolding(op, newResultTypes); Value sourceTensor = newOperands[0]; // Get the updated mixed-tile-sizes attribute. @@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp return failure(); SmallVector newResultTypes(op->getResultTypes()); - SmallVector newOperands = getNewOperands(op, newResultTypes); + SmallVector newOperands = + getUpdatedOperandsAfterCastOpFolding(op, newResultTypes); // Clone op auto newOp = clone(rewriter, op, newResultTypes, newOperands);