Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tensor] Make useful Tensor utilities public #126802

Merged
merged 1 commit into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
/// this method provides a check that it is worth doing the canonicalization.
bool canFoldIntoProducerOp(CastOp castOp);

/// Return true if any of the operands of `op` is a CastOp that can be folded
/// into its consumer, i.e. `op`. This is effectively a convenience wrapper for
/// `canFoldIntoConsumerOp`.
bool hasFoldableTensorCastOperand(Operation *op);

/// Assuming that `op` contains at least one operand that is a foldable CastOp
/// (i.e. `hasFoldableTensorCastOperand` returns true), calculate the updated
/// operands.
SmallVector<Value>
getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
SmallVector<Type> &newResTy);

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult foldTensorCast(Operation *op);
Expand Down
67 changes: 36 additions & 31 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
castOp.getType());
}

bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
  // Walk the operands looking for one produced by a tensor.cast whose
  // folding into this consumer (`op`) is legal. Block arguments have no
  // defining op and can never be foldable casts.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Value value = opOperand.get();
    if (llvm::isa<BlockArgument>(value))
      continue;
    if (auto castOp = value.getDefiningOp<tensor::CastOp>())
      if (canFoldIntoConsumerOp(castOp))
        return true;
  }
  return false;
}

SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
    DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");

  SmallVector<Value> updatedOperands;
  updatedOperands.reserve(op->getNumOperands());

  // `newResTy` is updated in place. This assumes the op's results are
  // ordered with the dpsInit-backed results first, so a running index
  // tracks which result type to overwrite next.
  int64_t resultIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    // When the operand comes from a foldable cast, bypass the cast and use
    // its source directly; otherwise keep the operand as-is.
    Value updated = (castOp && canFoldIntoConsumerOp(castOp))
                        ? castOp.getOperand()
                        : opOperand.get();
    updatedOperands.push_back(updated);
    // Only tensor-typed dpsInits drive result types; memref inits have no
    // corresponding result.
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(updated.getType()))
      newResTy[resultIdx++] = updated.getType();
  }
  return updatedOperands;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
Expand Down Expand Up @@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
isa<LoopLikeOpInterface>(op.getOperation()))
return false;

// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
if (llvm::isa<BlockArgument>(opOperand.get()))
return false;
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});

return hasTensorCastOperand;
}

static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
SmallVector<Type> &newResTy) {
SmallVector<Value> newOperands;
newOperands.reserve(op->getNumOperands());

// Assumes that the result has dpsInits followed by nonDpsInits.
int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
newResTy[dpsInitIdx++] = newOperands.back().getType();
}
return newOperands;
return hasFoldableTensorCastOperand(op);
}

// Given the (potentially) updated packed type, `newPackedTy`, generates an
Expand Down Expand Up @@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
return failure();

SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
SmallVector<Value> newOperands =
getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

// Get the updated mixed-tile-sizes attribute.
SmallVector<OpFoldResult> newMixedTileSizes =
Expand Down Expand Up @@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
return failure();

SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
SmallVector<Value> newOperands =
getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
Value sourceTensor = newOperands[0];

// Get the updated mixed-tile-sizes attribute.
Expand Down Expand Up @@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
return failure();

SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
SmallVector<Value> newOperands =
getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

// Clone op
auto newOp = clone(rewriter, op, newResultTypes, newOperands);
Expand Down