diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h index 7de7842baf98a..fc5cfffea27a7 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h @@ -203,11 +203,10 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding); // Insert shard op if there is not one that already has the same sharding. // Use newShardOp if it is not null. Otherwise create a new one. // May insert resharding if required. -// Return the target ShardOP (new or existing). -ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder, - ShardOp newShardOp); +// Potentially updates newShardOp. +void maybeInsertTargetShardingAnnotation(MeshSharding sharding, + OpOperand &operand, OpBuilder &builder, + ShardOp &newShardOp); void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, OpBuilder &builder); void maybeInsertSourceShardingAnnotation(MeshSharding sharding, diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 561b1ef3b1c39..12e1ec6d717ea 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -275,10 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { return type; } -ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder, - ShardOp newShardOp) { +void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, + OpOperand &operand, + OpBuilder &builder, + ShardOp &newShardOp) { OpBuilder::InsertionGuard insertionGuard(builder); Value operandValue = operand.get(); Operation *operandOp = operand.getOwner(); @@ -287,7 +287,10 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, if (shardOp && sharding == shardOp.getSharding() && !shardOp.getAnnotateForUsers()) { // No need for anything if the correct sharding is already set. - return newShardOp ? newShardOp : shardOp; + if (!newShardOp) { + newShardOp = shardOp; + } + return; } if (!newShardOp) { @@ -304,14 +307,14 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, }); if (!shardOp || shardOp.getAnnotateForUsers()) { - return newShardOp; + return; } auto newShardOp2 = builder.create(operandValue.getLoc(), newShardOp, newShardOp.getSharding(), /*annotate_for_users*/ true); rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2); - return newShardOp; + return; } void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, @@ -319,8 +322,7 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpBuilder &builder) { ShardOp newShardOp; for (auto &use : llvm::make_early_inc_range(result.getUses())) { - newShardOp = - maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp); + maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp); } }