Skip to content

Commit

Permalink
sharding propagation: add only one shardop for each result
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlimb committed Feb 10, 2025
1 parent 6a63361 commit d52ca9a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
9 changes: 6 additions & 3 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);

// Insert shard op if there is not one that already has the same sharding.
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
// Return the target ShardOP (new or existing).
ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder,
ShardOp newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
Expand Down
33 changes: 20 additions & 13 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder) {
ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder,
ShardOp newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
Expand All @@ -286,34 +287,40 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
// No need for anything the correct sharding is already set.
return;
return newShardOp ? newShardOp : shardOp;
}

auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
auto newShardOp =
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
if (!newShardOp) {
auto shardingOp =
builder.create<ShardingOp>(operandValue.getLoc(), sharding);
newShardOp =
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
}
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
return use.getOwner() == operandOp && use.get() == operandValue;
});

if (!shardOp || shardOp.getAnnotateForUsers()) {
return;
return newShardOp;
}

auto newShardOp2 =
builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
/*annotate_for_users*/ true);
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
newShardOp.getSharding(),
/*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
return newShardOp;
}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
ShardOp newShardOp;
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
maybeInsertTargetShardingAnnotation(sharding, use, builder);
newShardOp =
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}

Expand Down

0 comments on commit d52ca9a

Please sign in to comment.