Skip to content

Commit

Permalink
Adding sharding extraction operation and op tests
Browse files Browse the repository at this point in the history
and handling GetShardingOp in ShardingPropagation
  • Loading branch information
fschlimb committed Feb 10, 2025
1 parent d52ca9a commit 3c76df3
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 7 deletions.
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
OpBuilder<(ins "llvm::StringRef":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
)>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
let summary = "Get the sharding of the given tensor.";
let description = [{
This operation returns the sharding of the given tensor as a MeshSharding.
}];
let arguments = (ins
AnyRankedTensor:$source
);
let results = (outs
Mesh_Sharding:$result
);
let assemblyFormat = [{
$source attr-dict `:` type($source) `->` type($result)
}];
}

def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
let description = [{
Expand Down
22 changes: 16 additions & 6 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
ArrayRef<int64_t> static_halo_sizes,
ArrayRef<int64_t> static_sharded_dims_offsets) {
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(),
static_sharded_dims_offsets),
{});
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}

void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
Expand All @@ -475,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
{}, {}, {}, {});
}

void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}

void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (op->hasTrait<OpTrait::IsTerminator>() ||
(op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
return success();

if (!shardingOp) {
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
if (isa<ShardingOp>(op)) {
return success();
}
if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
if (!shardOp) {
return op.emitError("expected a shard op as source of get_sharding");
}
auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
return success();
}

ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
if (shardOp) {
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Mesh/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ func.func @mesh_shard_shape() {
return
}

// CHECK-LABEL: func @mesh_get_sharding
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
%s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
// CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
%0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
return %0 : !mesh.sharding
}

// CHECK-LABEL: func @mesh_shape
func.func @mesh_shape() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Mesh/spmdization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@

mesh.mesh @mesh_1d(shape = 2)

// CHECK-LABEL: func @return_sharding
func.func @return_sharding(
// CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
%arg0: tensor<2xf32>
// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
) -> (tensor<2xf32>, !mesh.sharding) {
%ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
%sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
%r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
// CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
}

// CHECK-LABEL: func @full_replication
func.func @full_replication(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
Expand Down

0 comments on commit 3c76df3

Please sign in to comment.