Skip to content

Commit

Permalink
maybeInsertTargetShardingAnnotation accepting reference only
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlimb committed Feb 12, 2025
1 parent 99cf24e commit 3729a86
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
9 changes: 4 additions & 5 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,10 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Insert shard op if there is not one that already has the same sharding.
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Return the target ShardOP (new or existing).
ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder,
ShardOp newShardOp);
// Potentially updates newShardOp.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand, OpBuilder &builder,
ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
Expand Down
20 changes: 11 additions & 9 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}

ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder,
ShardOp newShardOp) {
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder,
ShardOp &newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
Expand All @@ -287,7 +287,10 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
// No need for anything if the correct sharding is already set.
return newShardOp ? newShardOp : shardOp;
if (!newShardOp) {
newShardOp = shardOp;
}
return;
}

if (!newShardOp) {
Expand All @@ -304,23 +307,22 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
});

if (!shardOp || shardOp.getAnnotateForUsers()) {
return newShardOp;
return;
}

auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
newShardOp.getSharding(),
/*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
return newShardOp;
return;
}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
ShardOp newShardOp;
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
newShardOp =
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}

Expand Down

0 comments on commit 3729a86

Please sign in to comment.