[mlir][tensor] Make useful Tensor utilities public (llvm#126802)
1. Extract the main logic from `foldTensorCastPrecondition` into a
   dedicated helper hook, `hasFoldableTensorCastOperand`, so that the
   corresponding checks can be reused.

2. Rename `getNewOperands` to `getUpdatedOperandsAfterCastOpFolding` to
   better document what the helper does.

3. These updated hooks will be reused in llvm#123902. This PR makes
   them public.

**Note:** Moving these hooks to `Tensor/Utils` is not feasible because
`MLIRTensorUtils` depends on `MLIRTensorDialect` (CMake targets). If
these hooks were moved to `Utils`, it would create a dependency of
`MLIRTensorDialect` on `MLIRTensorUtils`, leading to a circular
dependency.
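
As a sketch of the intended reuse (illustrative only: `MyDpsOp` is a
hypothetical op implementing `DestinationStyleOpInterface`, includes/usings
are elided, and llvm#123902 may wire things up differently), a downstream
pattern could now call both hooks directly instead of duplicating the
precondition logic. The structure mirrors `FoldTensorCastProducerOp` in the
diff below.

struct FoldCastIntoMyDpsOp : public OpRewritePattern<MyDpsOp> {
  using OpRewritePattern<MyDpsOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MyDpsOp op,
                                PatternRewriter &rewriter) const override {
    // Reused precondition: bail out unless some operand is a foldable cast.
    if (!tensor::hasFoldableTensorCastOperand(op))
      return failure();

    // Seed the result types with the current ones; the helper refines the
    // DPS-init entries in place.
    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Re-create the op with cast-free operands; uses the same `clone` helper
    // as TensorOps.cpp below.
    Operation *newOp =
        clone(rewriter, op, newResultTypes, newOperands);

    // Cast any result whose type was refined back to the type the rest of
    // the IR expects.
    SmallVector<Value> replacements;
    for (auto [oldRes, newRes] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      Value repl = newRes;
      if (newRes.getType() != oldRes.getType())
        repl = rewriter.create<tensor::CastOp>(op.getLoc(), oldRes.getType(),
                                               newRes);
      replacements.push_back(repl);
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};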
banach-space authored and flovent committed Feb 13, 2025
1 parent 0da960f commit 0a4fb00
Showing 2 changed files with 48 additions and 31 deletions.
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -116,6 +116,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
 /// this method provides a check that it is worth doing the canonicalization.
 bool canFoldIntoProducerOp(CastOp castOp);
 
+/// Return true if any of the operands of `op` is a CastOp that can be folded
+/// into its consumer, i.e. `op`. This is effectively a convenience wrapper
+/// for `canFoldIntoConsumerOp`.
+bool hasFoldableTensorCastOperand(Operation *op);
+
+/// Assuming that `op` contains at least one operand that is a foldable CastOp
+/// (i.e. `hasFoldableTensorCastOperand` returns true), calculate the updated
+/// operands.
+SmallVector<Value>
+getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
+                                     SmallVector<Type> &newResTy);
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
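
Note that `newResTy` is an in/out parameter. A minimal calling-convention
sketch (editor's illustration, not part of the diff; `op` is assumed to be
any `DestinationStyleOpInterface` op inside a rewrite):

  // Seed with the current result types; entries for DPS inits whose
  // producing casts fold away are overwritten in place.
  SmallVector<Type> newResultTypes(op->getResultTypes());
  if (tensor::hasFoldableTensorCastOperand(op)) {
    SmallVector<Value> newOperands =
        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
    // `newOperands` bypasses the folded casts; `newResultTypes` now holds
    // the (typically more static) DPS-init types.
  }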
67 changes: 36 additions & 31 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
                                     castOp.getType());
 }
 
+bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
+  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+    if (llvm::isa<BlockArgument>(opOperand.get()))
+      return false;
+    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    return castOp && canFoldIntoConsumerOp(castOp);
+  });
+}
+
+SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
+    DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
+
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+}
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
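
A worked example of the bookkeeping above (hypothetical op and types, not
taken from the commit):

// Suppose the only DPS init of a destination-style op is produced by a
// generalizing cast:
//
//   %cast = tensor.cast %init : tensor<8x16xf32> to tensor<?x?xf32>
//   %res  = "my.dps_op"(%in, %cast)
//             : (tensor<4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
//
// The cast merely erases static shape information, so canFoldIntoConsumerOp
// succeeds and hasFoldableTensorCastOperand returns true. With newResTy
// seeded as {tensor<?x?xf32>}, getUpdatedOperandsAfterCastOpFolding returns
// {%in, %init} (the cast is bypassed) and rewrites newResTy to
// {tensor<8x16xf32>}, since DPS result types track their init types.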
@@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
       isa<LoopLikeOpInterface>(op.getOperation()))
     return false;
 
-  // If no operand comes from a tensor::CastOp and can be folded then fail.
-  bool hasTensorCastOperand =
-      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-        if (llvm::isa<BlockArgument>(opOperand.get()))
-          return false;
-        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-        return castOp && canFoldIntoConsumerOp(castOp);
-      });
-
-  return hasTensorCastOperand;
-}
-
-static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
-                                         SmallVector<Type> &newResTy) {
-  SmallVector<Value> newOperands;
-  newOperands.reserve(op->getNumOperands());
-
-  // Assumes that the result has dpsInits followed by nonDpsInits.
-  int64_t dpsInitIdx = 0;
-  for (OpOperand &opOperand : op->getOpOperands()) {
-    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-    bool fold = canFoldIntoConsumerOp(tensorCastOp);
-    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-    if (op.isDpsInit(&opOperand) &&
-        !llvm::isa<MemRefType>(newOperands.back().getType()))
-      newResTy[dpsInitIdx++] = newOperands.back().getType();
-  }
-  return newOperands;
+  return hasFoldableTensorCastOperand(op);
 }
 
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
@@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
     SmallVector<OpFoldResult> newMixedTileSizes =
@@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
@@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Clone op
     auto newOp = clone(rewriter, op, newResultTypes, newOperands);
