From 3c76df3d4552953b0d5fa6719b31d68796fda199 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Mon, 13 Jan 2025 15:36:54 +0100 Subject: [PATCH] Adding sharding extraction operation and op tests and handling GetShardingOp in ShardingPropagation --- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 21 ++++++++++++++++++ mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 22 ++++++++++++++----- .../Mesh/Transforms/ShardingPropagation.cpp | 2 +- .../Dialect/Mesh/Transforms/Spmdization.cpp | 9 ++++++++ mlir/test/Dialect/Mesh/ops.mlir | 10 +++++++++ mlir/test/Dialect/Mesh/spmdization.mlir | 14 ++++++++++++ 6 files changed, 71 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 531020930768e..031e6f63bcb42 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [ "ArrayRef":$split_axes, "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes, "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>, + OpBuilder<(ins "llvm::StringRef":$mesh, + "ArrayRef":$split_axes, + CArg<"ArrayRef", "{}">:$static_halo_sizes, + CArg<"ArrayRef", "{}">:$static_sharded_dims_offsets + )>, OpBuilder<(ins "mlir::mesh::MeshSharding":$from)> ]; let hasVerifier = 1; let hasCanonicalizer = 1; } +def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> { + let summary = "Get the sharding of the given tensor."; + let description = [{ + This operation returns the sharding of the given tensor as a MeshSharding. + }]; + let arguments = (ins + AnyRankedTensor:$source + ); + let results = (outs + Mesh_Sharding:$result + ); + let assemblyFormat = [{ + $source attr-dict `:` type($source) `->` type($result) + }]; +} + def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> { let summary = "Get the shard shape of a given process/device."; let description = [{ diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 2fff67c44a8ac..f84d467048522 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -454,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, ArrayRef split_axes, ArrayRef partial_axes, mesh::ReductionKind partial_type, - ArrayRef static_halo_sizes, - ArrayRef static_sharded_dims_offsets) { + ArrayRef static_halos, + ArrayRef static_offsets) { return build( b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes), ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), - static_sharded_dims_offsets), - {}); + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); } void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, @@ -475,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, {}, {}, {}, {}); } +void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, + llvm::StringRef mesh, ArrayRef split_axes, + ArrayRef static_halos, + ArrayRef static_offsets) { + return build( + b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh), + MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, + ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum), + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); +} + void ShardingOp::build( ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef split_axes, diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp index f96d54424a2fe..8c989cce63406 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp @@ -285,7 +285,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { ShardingInterface shardingOp = llvm::dyn_cast(op); if (op->hasTrait() || (op->hasTrait() && !shardingOp) || - llvm::isa(op)) + llvm::isa(op)) return success(); if (!shardingOp) { diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index e6fe0fd5d1e87..4ec8bbc0dff7d 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -738,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap, if (isa(op)) { return success(); } + if (auto getShardingOp = dyn_cast(op)) { + auto shardOp = getShardingOp.getSource().getDefiningOp(); + if (!shardOp) { + return op.emitError("expected a shard op as source of get_sharding"); + } + auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp()); + spmdizationMap.map(op.getResult(0), newSharding->getResult(0)); + return success(); + } ShardOp shardOp = llvm::dyn_cast(op); if (shardOp) { diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index 978de4939ee77..dae21655afb23 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -164,6 +164,16 @@ func.func @mesh_shard_shape() { return } +// CHECK-LABEL: func @mesh_get_sharding +// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> +func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding { + // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding + %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding + // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding + %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding + return %0 : !mesh.sharding +} + // CHECK-LABEL: func @mesh_shape func.func @mesh_shape() -> (index, index) { // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index c1b96fda0f4a7..59f7162e21013 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -4,6 +4,20 @@ mesh.mesh @mesh_1d(shape = 2) +// CHECK-LABEL: func @return_sharding +func.func @return_sharding( + // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32> + %arg0: tensor<2xf32> +// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) { +) -> (tensor<2xf32>, !mesh.sharding) { + %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding + %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32> + // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding + %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding + // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding + return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding +} + // CHECK-LABEL: func @full_replication func.func @full_replication( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>