From 300f788da10fd08401942c8b238b756222a3c98b Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Wed, 27 Nov 2024 16:38:24 +0100 Subject: [PATCH 01/10] Allowing constant-like operands to ShardingInterface ops Attaching ShardingInterface to arith::ConstantOp --- .../Arith/Transforms/ShardingInterfaceImpl.h | 23 +++++ mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 2 +- mlir/include/mlir/InitAllDialects.h | 2 + .../Dialect/Arith/Transforms/CMakeLists.txt | 1 + .../Transforms/ShardingInterfaceImpl.cpp | 99 +++++++++++++++++++ mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 19 +++- .../Mesh/Interfaces/ShardingInterface.cpp | 17 ++-- .../Mesh/Transforms/ShardingPropagation.cpp | 3 +- .../Dialect/Mesh/Transforms/Spmdization.cpp | 35 ++++--- .../Extensions/MeshShardingExtensions.cpp | 15 +-- mlir/test/Dialect/Arith/mesh-spmdize.cpp | 17 ++++ .../Dialect/Arith/sharding-propagation.mlir | 54 ++++++++++ 12 files changed, 251 insertions(+), 36 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h create mode 100644 mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp create mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.cpp create mode 100644 mlir/test/Dialect/Arith/sharding-propagation.mlir diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h new file mode 100644 index 0000000000000..5addffbe571be --- /dev/null +++ b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h @@ -0,0 +1,23 @@ +//===- ShardingInterfaceImpl.h - ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ +#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ + +namespace mlir { + +class DialectRegistry; + +namespace arith { + +void registerShardingInterfaceExternalModels(DialectRegistry &registry); + +} // namespace arith +} // namespace mlir + +#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h index 75cb096130ca6..210b82151ede4 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h @@ -62,7 +62,7 @@ class MeshSharding { ArrayRef dynamic_halo_sizes_ = {}, ArrayRef dynamic_sharded_dims_offsets_ = {}); ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; } - ::llvm::StringRef getMesh() const { return mesh.getValue(); } + ::llvm::StringRef getMesh() const { return mesh ? 
mesh.getValue() : ""; } ArrayRef getSplitAxes() const { return split_axes; } ArrayRef getPartialAxes() const { return partial_axes; } ReductionKind getPartialType() const { return partial_type; } diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 0da82825c8287..33bc89279c08c 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" @@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) { arith::registerBufferDeallocationOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferViewFlowOpInterfaceExternalModels(registry); + arith::registerShardingInterfaceExternalModels(registry); arith::registerValueBoundsOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 6149b35befe7d..30dd84aff120f 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms ExpandOps.cpp IntRangeOptimizations.cpp ReifyValueBounds.cpp + ShardingInterfaceImpl.cpp UnsignedWhenEquivalent.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp new file mode 100644 index 0000000000000..fc033294eb01b --- 
/dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -0,0 +1,99 @@ +//===- ShardingInterfaceImpl.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/IR/DialectRegistry.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::mesh; + +namespace { + +// Sharding of arith.constant +struct ConstantShardingInterface + : public ShardingInterface::ExternalModel { + SmallVector getLoopIteratorTypes(Operation *op) const { + auto ndims = 0; + if (auto type = dyn_cast(op->getResult(0).getType())) { + ndims = type.getRank(); + } + return SmallVector(ndims, + utils::IteratorType::parallel); + } + + SmallVector getIndexingMaps(Operation *op) const { + if (auto type = dyn_cast(op->getResult(0).getType())) { + return SmallVector(1, {AffineMap::getMultiDimIdentityMap( + type.getRank(), op->getContext())}); + } + return {}; + } + + // Indicate failure if no result sharding exists. + // Otherwise mirror result sharding if it is a tensor constant. + // Otherwise return replication option. 
+ FailureOr + getShardingOption(Operation *op, ArrayRef operandShardings, + ArrayRef resultShardings) const { + if (!resultShardings[0]) { + return failure(); + } + if (auto type = dyn_cast(op->getResult(0).getType())) { + ShardingArray axesArray(resultShardings[0].getSplitAxes().size()); + for (auto [i, axes] : + llvm::enumerate(resultShardings[0].getSplitAxes())) { + axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end()); + } + return ShardingOption(axesArray, resultShardings[0].getMeshAttr()); + } + return ShardingOption({}, resultShardings[0].getMeshAttr()); + } + + LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &spmdizationMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + auto cOp = cast(op); + auto value = dyn_cast(cOp.getValue()); + if (value) { + if (!value.isSplat() || !resultShardings[0]) { + // Currently non-splat constants are not supported. + return failure(); + } + auto sharding = resultShardings[0]; + auto newType = cast(shardType( + cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable), + sharding)); + auto newValue = value.resizeSplat(newType); + auto newOp = builder.create(op->getLoc(), newType, newValue); + spmdizationMap.map(op->getResult(0), newOp.getResult()); + spmdizationMap.map(op, newOp.getOperation()); + } else { + // `clone` will populate the mapping of old to new results. 
+ (void)builder.clone(*op, spmdizationMap); + } + return success(); + } +}; +} // namespace + +void mlir::arith::registerShardingInterfaceExternalModels( + DialectRegistry &registry) { + + registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) { + ConstantOp::template attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 33460ff25e9e4..352bf476e3f57 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -316,9 +316,13 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpBuilder &builder) { OpBuilder::InsertionGuard insertionGuard(builder); Value operandValue = operand.get(); - Operation *operandOp = operand.getOwner(); Operation *operandSrcOp = operandValue.getDefiningOp(); bool isBlockArg = !operandSrcOp; + if(!isBlockArg && operandSrcOp->hasTrait()) { + return; + } + + Operation *operandOp = operand.getOwner(); ShardOp shardOp = dyn_cast_or_null(operandSrcOp); if (shardOp && sharding == shardOp.getSharding() && @@ -710,8 +714,13 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const { MeshSharding::MeshSharding(Value rhs) { auto shardingOp = mlir::dyn_cast(rhs.getDefiningOp()); assert(shardingOp && "expected sharding op"); - *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(), - shardingOp.getPartialAxes().value_or(ArrayRef()), + auto splitAxes = shardingOp.getSplitAxes().getAxes(); + auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef()); + if(splitAxes.empty() && partialAxes.empty()) { + *this = MeshSharding(); + return; + } + *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes, shardingOp.getPartialType().value_or(ReductionKind::Sum), shardingOp.getStaticHaloSizes(), shardingOp.getStaticShardedDimsOffsets(), @@ -727,6 +736,10 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef static_sharded_dims_offsets_, ArrayRef dynamic_halo_sizes_, 
ArrayRef dynamic_sharded_dims_offsets_) { + if(split_axes_.empty() && partial_axes_.empty()) { + return MeshSharding(); + } + MeshSharding res; res.mesh = mesh_; res.split_axes.resize(split_axes_.size()); diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index c1f4d563d5b42..aae2d4ccfeed9 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -168,16 +168,16 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { // check operands and results type for (Type type : op->getOperandTypes()) - if (!llvm::isa(type)) + if (!llvm::isa(type) && !type.isIntOrIndexOrFloat()) return failure(); for (Type type : op->getResultTypes()) - if (!llvm::isa(type)) + if (!llvm::isa(type) && !type.isIntOrIndexOrFloat()) return failure(); // check loop types - SmallVector loopTypes = getLoopIteratorTypes(); - if (loopTypes.empty()) - return failure(); + // SmallVector loopTypes = getLoopIteratorTypes(); + // if (loopTypes.empty()) + // return failure(); // check maps SmallVector maps = getIndexingMaps(); @@ -448,7 +448,12 @@ static FailureOr getSharding(OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map) { Value operandValue = opOperand.get(); - auto operandType = cast(operandValue.getType()); + auto operandType = dyn_cast(operandValue.getType()); + if(!operandType) { + if(operandValue.getType().isIntOrIndexOrFloat()) + return MeshSharding(); + return failure(); + } SmallVector> splitAxes(operandType.getRank()); unsigned numDims = map.getNumDims(); for (auto it : llvm::enumerate(map.getResults())) { diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp index 4bd3b425219c1..f96d54424a2fe 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp @@ 
-282,11 +282,12 @@ static FailureOr selectShardingOption( // a `mesh.shard` operation for all remaining operands and results that do not // have sharding annotations. static LogicalResult visitOp(Operation *op, OpBuilder &builder) { + ShardingInterface shardingOp = llvm::dyn_cast(op); if (op->hasTrait() || + (op->hasTrait() && !shardingOp) || llvm::isa(op)) return success(); - ShardingInterface shardingOp = llvm::dyn_cast(op); if (!shardingOp) { op->emitOpError() << "sharding interface is not implemented."; return failure(); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 327ea0991e4e1..04932f11e6b43 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -636,14 +636,6 @@ shardedBlockArgumentTypes(Block &block, return res; } -void spmdizeTriviallyShardableOperation(Operation &op, - ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder); - static LogicalResult spmdizeOperation( Operation &op, ArrayRef spmdizedOperands, ArrayRef operandShardings, @@ -697,14 +689,15 @@ static std::vector getResultShardings(Operation &op) { std::vector res; res.reserve(op.getNumResults()); llvm::transform(op.getResults(), std::back_inserter(res), - [](OpResult result) { + [&op](OpResult result) { TypedValue rankedTensor = dyn_cast>(result); - if (!rankedTensor) { + if (!rankedTensor || op.hasTrait()) { + return MeshSharding(); + } + if (!result.hasOneUse()) { return MeshSharding(); } - - assert(result.hasOneUse()); Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::cast(userOp); return MeshSharding(shardOp.getSharding()); @@ -765,6 +758,7 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap, static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection 
&symbolTableCollection, OpBuilder &builder) { + SmallVector argLocations; llvm::transform(block.getArguments(), std::back_inserter(argLocations), [](BlockArgument arg) { return arg.getLoc(); }); @@ -796,8 +790,12 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, // Snapshot the original blocks to not mess up the iteration when adding new // blocks. SmallVector originalBlocks; - llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks), - [](Block &b) { return &b; }); + for (Block &b : op.getBlocks()) { + if (llvm::any_of(b.getOperations(), + [](Operation &op) { return isa(op); })) { + originalBlocks.push_back(&b); + } + } for (Block *block : originalBlocks) { if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection, @@ -823,10 +821,11 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, break; } } - assert(returnOp); - op.setType(FunctionType::get(op->getContext(), - op.getFunctionBody().front().getArgumentTypes(), - returnOp->getOperandTypes())); + if (returnOp) { + op.setType(FunctionType::get( + op->getContext(), op.getFunctionBody().front().getArgumentTypes(), + returnOp->getOperandTypes())); + } return success(); } diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp index f3e72abe7516e..6bb5d4a66f39e 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp @@ -22,10 +22,10 @@ using namespace mlir::mesh; namespace { -// Sharding of tensor.empty -struct EmptyOpShardingInterface - : public ShardingInterface::ExternalModel { +// Sharding of tensor.empty/tensor.splat +template +struct CreatorOpShardingInterface + : public ShardingInterface::ExternalModel, OpTy> { SmallVector getLoopIteratorTypes(Operation *op) const { auto ndims = mlir::cast(op->getResult(0).getType()).getRank(); return SmallVector(ndims, @@ -38,7 +38,7 @@ struct 
EmptyOpShardingInterface auto type = dyn_cast(val.getType()); if (!type) return {}; - return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}; + return SmallVector(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}); } LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, @@ -83,7 +83,7 @@ struct EmptyOpShardingInterface } } newOp = - builder.create(op->getLoc(), shardType, newOperands); + builder.create(op->getLoc(), shardType, newOperands); spmdizationMap.map(op->getResult(0), newOp->getResult(0)); } else { // `clone` will populate the mapping of old to new results. @@ -100,6 +100,7 @@ void mlir::tensor::registerShardingInterfaceExternalModels( DialectRegistry &registry) { registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { - EmptyOp::template attachInterface(*ctx); + EmptyOp::template attachInterface>(*ctx); + SplatOp::template attachInterface>(*ctx); }); } diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp new file mode 100644 index 0000000000000..0688e14b1cf72 --- /dev/null +++ b/mlir/test/Dialect/Arith/mesh-spmdize.cpp @@ -0,0 +1,17 @@ +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \ +// RUN: %s | FileCheck %s + +mesh.mesh @mesh4x4(shape = 4x4) + +// CHECK-LABEL: func @test_spmdize_constant +// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32> +// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32 +// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> +func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32> + %ci = arith.constant 434 : i32 + return %sharding_annotated_1 : 
tensor<1024x1024xf32> +} diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir new file mode 100644 index 0000000000000..19eb340549b0b --- /dev/null +++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s + +mesh.mesh @mesh4x4(shape = 4x4) + +// CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} { +// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32 +// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: 
[[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> +// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32> +func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + %ci = arith.constant 43.4e+00 : f32 + %o1 = tensor.empty() : tensor<1024x1024xf32> + %res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + return %res : tensor<1024x1024xf32> +} + +// CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} { +// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32 +// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : 
tensor<1024x1024xf32> +// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> +func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %ci = arith.constant 43.4e+00 : f32 + %o1 = tensor.empty() : tensor<1024x1024xf32> + %res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32> + return %sharding_annotated_1 : tensor<1024x1024xf32> +} From f41780fec565c7da7c59a7e04d5d12b1826a268b Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Wed, 4 Dec 2024 11:04:59 +0100 Subject: [PATCH 02/10] better handling of replicated tensors --- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 2 +- .../Mesh/Interfaces/ShardingInterface.h | 4 ++- mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 22 +++++++++--- .../Mesh/Interfaces/ShardingInterface.cpp | 36 +++++++++++-------- .../Dialect/Mesh/Transforms/Spmdization.cpp | 3 +- 5 files changed, 45 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h index 210b82151ede4..626f2fcf93b36 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h @@ -51,7 +51,7 @@ class MeshSharding { SmallVector dynamic_sharded_dims_offsets; public: - MeshSharding() = default; + MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr); MeshSharding(Value rhs); static 
MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef split_axes_, diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h index b4d25cef05a7b..14aad7f9f6783 100644 --- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h +++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h @@ -36,7 +36,9 @@ struct ShardingOption { bool empty = false; ShardingOption() = default; ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh) - : shardingArray(std::move(shardingArray)), mesh(mesh) {} + : shardingArray(std::move(shardingArray)), mesh(mesh) { + assert(this->mesh); + } static ShardingOption makeEmpty() { auto res = ShardingOption(); res.empty = true; diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 352bf476e3f57..5e342a855d6ae 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef shardedDimsOffsets = {}, ArrayRef haloSizes = {}) { + // 0d tensors cannot be sharded and must get replicated + if (inShape.empty()) { + assert(outShape.empty()); + return; + } + std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape), llvm::adl_begin(outShape)); @@ -318,7 +324,12 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, Value operandValue = operand.get(); Operation *operandSrcOp = operandValue.getDefiningOp(); bool isBlockArg = !operandSrcOp; - if(!isBlockArg && operandSrcOp->hasTrait()) { + { + auto opType = dyn_cast(operandValue.getType()); + assert(!opType || opType.getRank() > 0 || isFullReplication(sharding)); + } + if (!isa(operandValue.getType()) && operandSrcOp && + operandSrcOp->hasTrait()) { return; } @@ -711,13 +722,15 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const { return 
!(*this == rhs); } +MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {} + MeshSharding::MeshSharding(Value rhs) { auto shardingOp = mlir::dyn_cast(rhs.getDefiningOp()); assert(shardingOp && "expected sharding op"); auto splitAxes = shardingOp.getSplitAxes().getAxes(); auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef()); if(splitAxes.empty() && partialAxes.empty()) { - *this = MeshSharding(); + *this = MeshSharding(shardingOp.getMeshAttr()); return; } *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes, @@ -736,12 +749,11 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef static_sharded_dims_offsets_, ArrayRef dynamic_halo_sizes_, ArrayRef dynamic_sharded_dims_offsets_) { + MeshSharding res(mesh_); if(split_axes_.empty() && partial_axes_.empty()) { - return MeshSharding(); + return res; } - MeshSharding res; - res.mesh = mesh_; res.split_axes.resize(split_axes_.size()); for (auto [i, axis] : llvm::enumerate(split_axes_)) { res.split_axes[i] = diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index aae2d4ccfeed9..aaffe759b0cef 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -286,18 +286,22 @@ mesh::detail::defaultGetShardingOption(Operation *op, continue; AffineMap map = maps[numOperands + shardingIt.index()]; anyShardingInResultsOrOperands = true; - // Handle the split axes: calculate the corresponding loop index for each - // split axes sub-array, and then store the sub-array to - // shardingOption[index] - for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { - AffineExpr expr = std::get<0>(it); - ArrayRef axes = std::get<1>(it).asArrayRef(); - auto dim = cast(expr); - unsigned index = dim.getPosition(); - visitedLoopIndices.insert(index); - if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(), - 
axes, index))) - return failure(); + if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) { + shardingOption.mesh = shardAttr.getMeshAttr(); + } else { + // Handle the split axes: calculate the corresponding loop index for each + // split axes sub-array, and then store the sub-array to + // shardingOption[index] + for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { + AffineExpr expr = std::get<0>(it); + ArrayRef axes = std::get<1>(it).asArrayRef(); + auto dim = cast(expr); + unsigned index = dim.getPosition(); + visitedLoopIndices.insert(index); + if (failed(fillShardingOption(op, shardingOption, + shardAttr.getMeshAttr(), axes, index))) + return failure(); + } } // Handle the partial axes: at this stage, the exact loop index/indices @@ -323,7 +327,7 @@ mesh::detail::defaultGetShardingOption(Operation *op, if (!shardAttr) continue; - anyShardingInResultsOrOperands = true; + anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty(); AffineMap map = maps[shardingIt.index()]; unsigned numDims = map.getNumDims(); @@ -454,6 +458,10 @@ static FailureOr getSharding(OpOperand &opOperand, return MeshSharding(); return failure(); } + // 0d tensors cannot be sharded and must get replicated + if (operandType.getRank() == 0) { + return MeshSharding(shardingOption.mesh); + } SmallVector> splitAxes(operandType.getRank()); unsigned numDims = map.getNumDims(); for (auto it : llvm::enumerate(map.getResults())) { @@ -584,7 +592,7 @@ static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding) { if (isa(value.getType())) { - return sharding && isFullReplication(sharding); + return isFullReplication(sharding); } return !sharding; diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 04932f11e6b43..27297a8be5d06 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -561,7 +561,8 @@ 
TypedValue reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, TypedValue sourceUnshardedValue, TypedValue sourceShard) { // If source and destination sharding are the same, no need to do anything. - if (sourceSharding == targetSharding) { + if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) && + isFullReplication(targetSharding))) { return sourceShard; } From c1324d3868f7a1926a4754823fc269ae8226aade Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Thu, 5 Dec 2024 12:55:23 +0100 Subject: [PATCH 03/10] canonicalize ShardOp and ShardingOp --- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 3 +- mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 87 +++++++++++++++++-- .../Dialect/Mesh/Transforms/Spmdization.cpp | 2 +- mlir/test/Dialect/Mesh/canonicalization.mlir | 40 ++++++++- 4 files changed, 122 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 6039e61a93fad..531020930768e 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -28,7 +28,7 @@ class Mesh_Op traits = []> : Op { } -def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> { +def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> { let summary = "Description of a device/process mesh."; let description = [{ The mesh.mesh operation is a symbol operation that identifies a specific @@ -460,6 +460,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [ (`annotate_for_users` $annotate_for_users^)? 
attr-dict `:` type($result) }]; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 5e342a855d6ae..6a1498c0f6814 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -594,9 +594,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { namespace { // Sharding annotations "halo sizes" and "sharded dims offsets" // are a mix of attributes and dynamic values. This canonicalization moves -// constant values to the respective attribute lists and so minimizes the number +// constant values to the respective attribute lists, minimizing the number // of values. -class FoldDynamicLists final : public OpRewritePattern { +// It also removes sharded_dims_sizes and halos if they are effectively "empty". +class NormalizeSharding final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -608,14 +609,39 @@ class FoldDynamicLists final : public OpRewritePattern { op.getDynamicShardedDimsOffsets(), b); // No constant operands were folded, just return; - if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) && - failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) { - return failure(); - } + bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) || + succeeded(foldDynamicIndexList(mixedOffs, true)); auto halos = decomposeMixedValues(mixedHalos); auto offs = decomposeMixedValues(mixedOffs); + if (halos.second.empty() && !halos.first.empty()) { + if (halos.first[0] == 0 && llvm::all_equal(halos.first)) { + halos.first.clear(); + modified = true; + } + } + + if (offs.second.empty() && !offs.first.empty()) { + assert(offs.first.size() >= 2); + auto diff = offs.first[1] - offs.first[0]; + bool all_same = offs.first.size() > 2; + for (auto i = 2u; i < offs.first.size(); ++i) { + if (offs.first[i] - 
offs.first[i - 1] != diff) { + all_same = false; + break; + } + } + if (all_same) { + offs.first.clear(); + modified = true; + } + } + + if (!modified) { + return failure(); + } + op.setStaticHaloSizes(halos.first); op.getDynamicHaloSizesMutable().assign(halos.second); op.setStaticShardedDimsOffsets(offs.first); @@ -628,7 +654,7 @@ class FoldDynamicLists final : public OpRewritePattern { void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// @@ -796,6 +822,53 @@ void ShardOp::getAsmResultNames( setNameFn(getResult(), "sharding_annotated"); } +namespace { +// Determine if the given ShardOp is a duplicate of another ShardOp +// on the same value. This can happen if constant values are sharded. +class FoldDuplicateShardOp final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override { + // Get the use-list of the value being sharded and check if it has more than + // one use. + Value value = op.getSrc(); + if (value.hasOneUse() || value.getDefiningOp()) { + return failure(); + } + + // Iterate through the uses of the value to find a duplicate ShardOp. + for (auto &use : value.getUses()) { + if (use.getOwner() != op.getOperation()) { + auto otherOp = dyn_cast(use.getOwner()); + if (!otherOp || !otherOp->isBeforeInBlock(op)) { + return failure(); + } + // Create a MeshSharding object for the current and the other ShardOp + // If the two are equal replace current op with the other op. 
+ MeshSharding currentSharding(op.getSharding()); + MeshSharding otherSharding(otherOp.getSharding()); + if (currentSharding == otherSharding) { + b.replaceAllUsesWith(op.getResult(), otherOp.getResult()); + b.eraseOp(op.getOperation()); + } else { + // use the other sharding as input for op + op.getSrcMutable().assign(otherOp.getResult()); + } + return success(); + } + } + + return failure(); + } +}; +} // namespace + +void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, + mlir::MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // mesh.process_multi_index op //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 27297a8be5d06..e6fe0fd5d1e87 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -693,7 +693,7 @@ static std::vector getResultShardings(Operation &op) { [&op](OpResult result) { TypedValue rankedTensor = dyn_cast>(result); - if (!rankedTensor || op.hasTrait()) { + if (!rankedTensor) { return MeshSharding(); } if (!result.hasOneUse()) { diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir index f0112d689805d..aff07bbf8a214 100644 --- a/mlir/test/Dialect/Mesh/canonicalization.mlir +++ b/mlir/test/Dialect/Mesh/canonicalization.mlir @@ -207,4 +207,42 @@ func.func @test_shard_offs() -> !mesh.sharding { // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding return %sharding : !mesh.sharding -} \ No newline at end of file +} + +// CHECK-LABEL: func @test_duplicate_shardops +func.func 
@test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding + %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding + %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> + %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> + return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> +} + +// CHECK-LABEL: func @test_duplicate_shardops_diff +func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes 
= {{\[\[}}0, 1]] : !mesh.sharding + %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding + // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32> + %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32> + %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> + return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> +} From 1d861861719b69736c7bf3952ca76b5d439cfcd4 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Thu, 19 Dec 2024 13:19:32 +0100 Subject: [PATCH 04/10] sharding propagation: add only one shardop for each result --- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 9 ++++-- mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 33 +++++++++++++-------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h index 626f2fcf93b36..7de7842baf98a 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h @@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh, Type shardType(Type type, MeshOp mesh, MeshSharding sharding); // Insert shard op if there is not one that already has the same sharding. +// Use newShardOp if it is not null. Otherwise create a new one. // May insert resharding if required. 
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder); +// Return the target ShardOP (new or existing). +ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding, + OpOperand &operand, + OpBuilder &builder, + ShardOp newShardOp); void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, OpBuilder &builder); void maybeInsertSourceShardingAnnotation(MeshSharding sharding, diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 6a1498c0f6814..2fff67c44a8ac 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -275,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { return type; } -void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder) { +ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, + OpOperand &operand, + OpBuilder &builder, + ShardOp newShardOp) { OpBuilder::InsertionGuard insertionGuard(builder); Value operandValue = operand.get(); Operation *operandOp = operand.getOwner(); @@ -286,13 +287,16 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, if (shardOp && sharding == shardOp.getSharding() && !shardOp.getAnnotateForUsers()) { // No need for anything the correct sharding is already set. - return; + return newShardOp ? 
newShardOp : shardOp; } - auto shardingOp = builder.create(operandValue.getLoc(), sharding); - auto newShardOp = - builder.create(operandValue.getLoc(), operandValue, shardingOp, - /*annotate_for_users*/ false); + if (!newShardOp) { + auto shardingOp = + builder.create(operandValue.getLoc(), sharding); + newShardOp = + builder.create(operandValue.getLoc(), operandValue, shardingOp, + /*annotate_for_users*/ false); + } IRRewriter rewriter(builder); rewriter.replaceUsesWithIf( operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) { @@ -300,20 +304,23 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, }); if (!shardOp || shardOp.getAnnotateForUsers()) { - return; + return newShardOp; } - auto newShardOp2 = - builder.create(operandValue.getLoc(), newShardOp, shardingOp, - /*annotate_for_users*/ true); + auto newShardOp2 = builder.create(operandValue.getLoc(), newShardOp, + newShardOp.getSharding(), + /*annotate_for_users*/ true); rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2); + return newShardOp; } void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, OpBuilder &builder) { + ShardOp newShardOp; for (auto &use : llvm::make_early_inc_range(result.getUses())) { - maybeInsertTargetShardingAnnotation(sharding, use, builder); + newShardOp = + maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp); } } From 70fb9a5a04e8443bc18f70782074f90c30b03e2a Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Mon, 13 Jan 2025 15:36:54 +0100 Subject: [PATCH 05/10] Adding sharding extraction operation and op tests and handling GetShardingOp in ShardingPropagation --- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 21 ++++++++++++++++++ mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 22 ++++++++++++++----- .../Mesh/Transforms/ShardingPropagation.cpp | 2 +- .../Dialect/Mesh/Transforms/Spmdization.cpp | 9 ++++++++ mlir/test/Dialect/Mesh/ops.mlir | 10 +++++++++ 
mlir/test/Dialect/Mesh/spmdization.mlir | 14 ++++++++++++ 6 files changed, 71 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 531020930768e..031e6f63bcb42 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [ "ArrayRef":$split_axes, "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes, "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>, + OpBuilder<(ins "llvm::StringRef":$mesh, + "ArrayRef":$split_axes, + CArg<"ArrayRef", "{}">:$static_halo_sizes, + CArg<"ArrayRef", "{}">:$static_sharded_dims_offsets + )>, OpBuilder<(ins "mlir::mesh::MeshSharding":$from)> ]; let hasVerifier = 1; let hasCanonicalizer = 1; } +def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> { + let summary = "Get the sharding of the given tensor."; + let description = [{ + This operation returns the sharding of the given tensor as a MeshSharding. 
+ }]; + let arguments = (ins + AnyRankedTensor:$source + ); + let results = (outs + Mesh_Sharding:$result + ); + let assemblyFormat = [{ + $source attr-dict `:` type($source) `->` type($result) + }]; +} + def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> { let summary = "Get the shard shape of a given process/device."; let description = [{ diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 2fff67c44a8ac..f84d467048522 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -454,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, ArrayRef split_axes, ArrayRef partial_axes, mesh::ReductionKind partial_type, - ArrayRef static_halo_sizes, - ArrayRef static_sharded_dims_offsets) { + ArrayRef static_halos, + ArrayRef static_offsets) { return build( b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes), ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), - static_sharded_dims_offsets), - {}); + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); } void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, @@ -475,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, {}, {}, {}, {}); } +void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, + llvm::StringRef mesh, ArrayRef split_axes, + ArrayRef static_halos, + ArrayRef static_offsets) { + return build( + b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh), + MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, + ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum), + 
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); +} + void ShardingOp::build( ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef split_axes, diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp index f96d54424a2fe..8c989cce63406 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp @@ -285,7 +285,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { ShardingInterface shardingOp = llvm::dyn_cast(op); if (op->hasTrait() || (op->hasTrait() && !shardingOp) || - llvm::isa(op)) + llvm::isa(op)) return success(); if (!shardingOp) { diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index e6fe0fd5d1e87..4ec8bbc0dff7d 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -738,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap, if (isa(op)) { return success(); } + if (auto getShardingOp = dyn_cast(op)) { + auto shardOp = getShardingOp.getSource().getDefiningOp(); + if (!shardOp) { + return op.emitError("expected a shard op as source of get_sharding"); + } + auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp()); + spmdizationMap.map(op.getResult(0), newSharding->getResult(0)); + return success(); + } ShardOp shardOp = llvm::dyn_cast(op); if (shardOp) { diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index 978de4939ee77..dae21655afb23 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -164,6 +164,16 @@ func.func @mesh_shard_shape() { return } +// CHECK-LABEL: func @mesh_get_sharding +// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> +func.func 
@mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding { + // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding + %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding + // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding + %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding + return %0 : !mesh.sharding +} + // CHECK-LABEL: func @mesh_shape func.func @mesh_shape() -> (index, index) { // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index c1b96fda0f4a7..59f7162e21013 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -4,6 +4,20 @@ mesh.mesh @mesh_1d(shape = 2) +// CHECK-LABEL: func @return_sharding +func.func @return_sharding( + // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32> + %arg0: tensor<2xf32> +// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) { +) -> (tensor<2xf32>, !mesh.sharding) { + %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding + %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32> + // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding + %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding + // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding + return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding +} + // CHECK-LABEL: func @full_replication func.func @full_replication( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> From 6f42ee5ed6a7b9ef414a9500bdd73da368539930 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 28 Jan 2025 10:39:46 +0100 Subject: [PATCH 06/10] comments adding libs clang-format renaming mesh-spmdize.cpp -> mesh-spmdize.mlir and fixing format --- .../Dialect/Arith/Transforms/CMakeLists.txt | 2 ++ 
.../Arith/Transforms/ShardingInterfaceImpl.cpp | 2 +- mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 9 ++++++--- .../Mesh/Interfaces/ShardingInterface.cpp | 9 ++------- .../Dialect/Mesh/Transforms/Spmdization.cpp | 2 +- .../Extensions/MeshShardingExtensions.cpp | 18 +++++++++++------- mlir/test/Dialect/Arith/mesh-spmdize.cpp | 17 ----------------- mlir/test/Dialect/Arith/mesh-spmdize.mlir | 17 +++++++++++++++++ 8 files changed, 40 insertions(+), 36 deletions(-) delete mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.cpp create mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.mlir diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 30dd84aff120f..f96bda603baa6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -27,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms MLIRInferIntRangeInterface MLIRIR MLIRMemRefDialect + MLIRMeshDialect MLIRPass + MLIRShardingInterface MLIRTensorDialect MLIRTransforms MLIRTransformUtils diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp index fc033294eb01b..f31db49067756 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -19,7 +19,7 @@ using namespace mlir::mesh; namespace { -// Sharding of arith.empty/arith.splat +// Sharding of arith.constant struct ConstantShardingInterface : public ShardingInterface::ExternalModel { diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index f84d467048522..c789fc527e3f6 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -286,7 +286,7 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, ShardOp shardOp = dyn_cast(operandOp); if (shardOp && sharding == shardOp.getSharding() && 
!shardOp.getAnnotateForUsers()) { - // No need for anything the correct sharding is already set. + // No need for anything if the correct sharding is already set. return newShardOp ? newShardOp : shardOp; } @@ -639,6 +639,8 @@ class NormalizeSharding final : public OpRewritePattern { } } + // Remove sharded dims offsets if they are effectively the default values, + // e.g. if they define equi-distance between all neighboring shards. if (offs.second.empty() && !offs.first.empty()) { assert(offs.first.size() >= 2); auto diff = offs.first[1] - offs.first[0]; @@ -772,7 +774,8 @@ MeshSharding::MeshSharding(Value rhs) { assert(shardingOp && "expected sharding op"); auto splitAxes = shardingOp.getSplitAxes().getAxes(); auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef()); - if(splitAxes.empty() && partialAxes.empty()) { + // If splitAxes and partialAxes are empty, use "empty" constructor. + if (splitAxes.empty() && partialAxes.empty()) { *this = MeshSharding(shardingOp.getMeshAttr()); return; } @@ -793,7 +796,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef dynamic_halo_sizes_, ArrayRef dynamic_sharded_dims_offsets_) { MeshSharding res(mesh_); - if(split_axes_.empty() && partial_axes_.empty()) { + if (split_axes_.empty() && partial_axes_.empty()) { return res; } diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index aaffe759b0cef..f427d004c558f 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -174,11 +174,6 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { if (!llvm::isa(type) && !type.isIntOrIndexOrFloat()) return failure(); - // check loop types - // SmallVector loopTypes = getLoopIteratorTypes(); - // if (loopTypes.empty()) - // return failure(); - // check maps SmallVector maps = getIndexingMaps(); if (maps.empty()) @@ -453,8 +448,8 @@ static 
FailureOr getSharding(OpOperand &opOperand, AffineMap map) { Value operandValue = opOperand.get(); auto operandType = dyn_cast(operandValue.getType()); - if(!operandType) { - if(operandValue.getType().isIntOrIndexOrFloat()) + if (!operandType) { + if (operandValue.getType().isIntOrIndexOrFloat()) return MeshSharding(); return failure(); } diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 4ec8bbc0dff7d..601af0200e785 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -690,7 +690,7 @@ static std::vector getResultShardings(Operation &op) { std::vector res; res.reserve(op.getNumResults()); llvm::transform(op.getResults(), std::back_inserter(res), - [&op](OpResult result) { + [](OpResult result) { TypedValue rankedTensor = dyn_cast>(result); if (!rankedTensor) { diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp index 6bb5d4a66f39e..b2acbf20b3fb9 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp @@ -23,9 +23,10 @@ using namespace mlir::mesh; namespace { // Sharding of tensor.empty/tensor.splat -template +template struct CreatorOpShardingInterface - : public ShardingInterface::ExternalModel, OpTy> { + : public ShardingInterface::ExternalModel, + OpTy> { SmallVector getLoopIteratorTypes(Operation *op) const { auto ndims = mlir::cast(op->getResult(0).getType()).getRank(); return SmallVector(ndims, @@ -38,7 +39,9 @@ struct CreatorOpShardingInterface auto type = dyn_cast(val.getType()); if (!type) return {}; - return SmallVector(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}); + return SmallVector( + op->getNumOperands() + op->getNumResults(), + {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}); } 
LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, @@ -82,8 +85,7 @@ struct CreatorOpShardingInterface newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]); } } - newOp = - builder.create(op->getLoc(), shardType, newOperands); + newOp = builder.create(op->getLoc(), shardType, newOperands); spmdizationMap.map(op->getResult(0), newOp->getResult(0)); } else { // `clone` will populate the mapping of old to new results. @@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { - EmptyOp::template attachInterface>(*ctx); - SplatOp::template attachInterface>(*ctx); + EmptyOp::template attachInterface>( + *ctx); + SplatOp::template attachInterface>( + *ctx); }); } diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp deleted file mode 100644 index 0688e14b1cf72..0000000000000 --- a/mlir/test/Dialect/Arith/mesh-spmdize.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh4x4(shape = 4x4) - -// CHECK-LABEL: func @test_spmdize_constant -// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32> -// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32 -// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> -func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { - %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32> - %ci = arith.constant 434 : i32 - return %sharding_annotated_1 : tensor<1024x1024xf32> -} diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir new file mode 100644 index 
0000000000000..6b55dd533a92c --- /dev/null +++ b/mlir/test/Dialect/Arith/mesh-spmdize.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \ +// RUN: %s | FileCheck %s + +mesh.mesh @mesh4x4(shape = 4x4) + +// CHECK-LABEL: func @test_spmdize_constant +// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : +// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : +// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> +func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32> + %ci = arith.constant 434 : i32 + return %sharding_annotated_1 : tensor<1024x1024xf32> +} From ec4a18ee34ff54084ac00666c7b77001cf8026d7 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 11 Feb 2025 12:46:44 +0100 Subject: [PATCH 07/10] assert expected ArrayRerf argument size --- .../Arith/Transforms/ShardingInterfaceImpl.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp index f31db49067756..ff1625877efcb 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -46,18 +46,20 @@ struct ConstantShardingInterface FailureOr getShardingOption(Operation *op, ArrayRef operandShardings, ArrayRef resultShardings) const { - if (!resultShardings[0]) { + assert(resultShardings.size() == 1 && + "Expecting exactly one result sharding for arith.constant"); + auto resultSharding = resultShardings[0]; + if (!resultSharding) { return failure(); } if (auto type = dyn_cast(op->getResult(0).getType())) { - 
ShardingArray axesArray(resultShardings[0].getSplitAxes().size()); - for (auto [i, axes] : - llvm::enumerate(resultShardings[0].getSplitAxes())) { + ShardingArray axesArray(resultSharding.getSplitAxes().size()); + for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) { axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end()); } - return ShardingOption(axesArray, resultShardings[0].getMeshAttr()); + return ShardingOption(axesArray, resultSharding.getMeshAttr()); } - return ShardingOption({}, resultShardings[0].getMeshAttr()); + return ShardingOption({}, resultSharding.getMeshAttr()); } LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, @@ -67,8 +69,7 @@ struct ConstantShardingInterface SymbolTableCollection &symbolTable, OpBuilder &builder) const { auto cOp = cast(op); - auto value = dyn_cast(cOp.getValue()); - if (value) { + if (auto value = dyn_cast(cOp.getValue())) { if (!value.isSplat() || !resultShardings[0]) { // Currently non-splat constants are not supported. return failure(); From e8ccad120aaf9921c0c6ea29b8938085d59ff613 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 11 Feb 2025 13:12:46 +0100 Subject: [PATCH 08/10] added sharding exmpample --- mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp index ff1625877efcb..62d137a4cfb0e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -20,6 +20,11 @@ using namespace mlir::mesh; namespace { // Sharding of arith.constant +// RankedTensor constants can be sharded like any other tensor. 
+// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> +// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding +// Scalar constants are always replicated and need no sharding annotation. + struct ConstantShardingInterface : public ShardingInterface::ExternalModel { From 99cf24e1e7f477ac826431be6092179f12ac01a0 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Wed, 12 Feb 2025 11:01:57 +0100 Subject: [PATCH 09/10] comments and nicer code (from review) --- mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 35 +++++++++++++++------------- mlir/test/Dialect/Mesh/ops.mlir | 2 -- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index c789fc527e3f6..561b1ef3b1c39 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -629,30 +629,33 @@ class NormalizeSharding final : public OpRewritePattern { bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) || succeeded(foldDynamicIndexList(mixedOffs, true)); - auto halos = decomposeMixedValues(mixedHalos); - auto offs = decomposeMixedValues(mixedOffs); + auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos); + auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs); - if (halos.second.empty() && !halos.first.empty()) { - if (halos.first[0] == 0 && llvm::all_equal(halos.first)) { - halos.first.clear(); + if (dynamicHalos.empty() && !staticHalos.empty()) { + if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) { + staticHalos.clear(); modified = true; } } // Remove sharded dims offsets if they are effectively the default values, // e.g. if they define equi-distance between all neighboring shards. 
- if (offs.second.empty() && !offs.first.empty()) { - assert(offs.first.size() >= 2); - auto diff = offs.first[1] - offs.first[0]; - bool all_same = offs.first.size() > 2; - for (auto i = 2u; i < offs.first.size(); ++i) { - if (offs.first[i] - offs.first[i - 1] != diff) { + // Requires static-only offsets. Compares the first distance as the + // difference between the first two offsets. Only if all consecutive + // distances are the same, the offsets are removed. + if (dynamicOffs.empty() && !staticOffs.empty()) { + assert(staticOffs.size() >= 2); + auto diff = staticOffs[1] - staticOffs[0]; + bool all_same = staticOffs.size() > 2; + for (auto i = 2u; i < staticOffs.size(); ++i) { + if (staticOffs[i] - staticOffs[i - 1] != diff) { all_same = false; break; } } if (all_same) { - offs.first.clear(); + staticOffs.clear(); modified = true; } } @@ -661,10 +664,10 @@ class NormalizeSharding final : public OpRewritePattern { return failure(); } - op.setStaticHaloSizes(halos.first); - op.getDynamicHaloSizesMutable().assign(halos.second); - op.setStaticShardedDimsOffsets(offs.first); - op.getDynamicShardedDimsOffsetsMutable().assign(offs.second); + op.setStaticHaloSizes(staticHalos); + op.getDynamicHaloSizesMutable().assign(dynamicHalos); + op.setStaticShardedDimsOffsets(staticOffs); + op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs); return success(); } diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index dae21655afb23..43a75bf3d8040 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -167,8 +167,6 @@ func.func @mesh_shard_shape() { // CHECK-LABEL: func @mesh_get_sharding // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding - %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding // CHECK-NEXT: mesh.get_sharding %[[ARG]] : 
tensor<4x8xf32> -> !mesh.sharding %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding return %0 : !mesh.sharding From 3729a86b9f961d768b3bbf1109d2b4f2dc9ecd46 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Wed, 12 Feb 2025 11:58:25 +0100 Subject: [PATCH 10/10] maybeInsertTargetShardingAnnotation accepting reference only --- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 9 ++++----- mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 20 +++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h index 7de7842baf98a..fc5cfffea27a7 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h @@ -203,11 +203,10 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding); // Insert shard op if there is not one that already has the same sharding. // Use newShardOp if it is not null. Otherwise create a new one. // May insert resharding if required. -// Return the target ShardOP (new or existing). -ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder, - ShardOp newShardOp); +// Potentially updates newShardOp. 
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding, + OpOperand &operand, OpBuilder &builder, + ShardOp &newShardOp); void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, OpBuilder &builder); void maybeInsertSourceShardingAnnotation(MeshSharding sharding, diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 561b1ef3b1c39..12e1ec6d717ea 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -275,10 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { return type; } -ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder, - ShardOp newShardOp) { +void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, + OpOperand &operand, + OpBuilder &builder, + ShardOp &newShardOp) { OpBuilder::InsertionGuard insertionGuard(builder); Value operandValue = operand.get(); Operation *operandOp = operand.getOwner(); @@ -287,7 +287,10 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, if (shardOp && sharding == shardOp.getSharding() && !shardOp.getAnnotateForUsers()) { // No need for anything if the correct sharding is already set. - return newShardOp ? 
newShardOp : shardOp; + if (!newShardOp) { + newShardOp = shardOp; + } + return; } if (!newShardOp) { @@ -304,14 +307,14 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, }); if (!shardOp || shardOp.getAnnotateForUsers()) { - return newShardOp; + return; } auto newShardOp2 = builder.create(operandValue.getLoc(), newShardOp, newShardOp.getSharding(), /*annotate_for_users*/ true); rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2); - return newShardOp; + return; } void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, @@ -319,8 +322,7 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpBuilder &builder) { ShardOp newShardOp; for (auto &use : llvm::make_early_inc_range(result.getUses())) { - newShardOp = - maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp); + maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp); } }