From 0b0c01e3740235fc94c3f3a9344e76e7977138f1 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach
Date: Wed, 12 Feb 2025 12:44:48 +0100
Subject: [PATCH] [MLIR][mesh] Mesh fixes (#124724)

A collection of fixes to the mesh dialect:
- allow constants in sharding propagation/spmdization
- fixes to tensor replication (e.g. 0d tensors)
- improved canonicalization
- fix sharding propagation, which incorrectly generated too many ShardOps

The new operation `mesh.get_sharding` enables exchanging sharding
information (e.g. on function boundaries).
---
 .../Arith/Transforms/ShardingInterfaceImpl.h  |  23 +++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  10 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  24 ++-
 .../Mesh/Interfaces/ShardingInterface.h       |   4 +-
 mlir/include/mlir/InitAllDialects.h           |   2 +
 .../Dialect/Arith/Transforms/CMakeLists.txt   |   3 +
 .../Transforms/ShardingInterfaceImpl.cpp      | 105 ++++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 189 +++++++++++++++---
 .../Mesh/Interfaces/ShardingInterface.cpp     |  52 +++--
 .../Mesh/Transforms/ShardingPropagation.cpp   |   5 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  43 ++--
 .../Extensions/MeshShardingExtensions.cpp     |  21 +-
 mlir/test/Dialect/Arith/mesh-spmdize.mlir     |  17 ++
 .../Dialect/Arith/sharding-propagation.mlir   |  54 +++++
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  40 +++-
 mlir/test/Dialect/Mesh/ops.mlir               |   8 +
 mlir/test/Dialect/Mesh/spmdization.mlir       |  14 ++
 17 files changed, 525 insertions(+), 89 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.mlir
 create mode 100644 mlir/test/Dialect/Arith/sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
new file mode 100644
index 0000000000000..5addffbe571be
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ +#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ + +namespace mlir { + +class DialectRegistry; + +namespace arith { + +void registerShardingInterfaceExternalModels(DialectRegistry ®istry); + +} // namespace arith +} // namespace mlir + +#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h index 75cb096130ca6..fc5cfffea27a7 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h @@ -51,7 +51,7 @@ class MeshSharding { SmallVector dynamic_sharded_dims_offsets; public: - MeshSharding() = default; + MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr); MeshSharding(Value rhs); static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef split_axes_, @@ -62,7 +62,7 @@ class MeshSharding { ArrayRef dynamic_halo_sizes_ = {}, ArrayRef dynamic_sharded_dims_offsets_ = {}); ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; } - ::llvm::StringRef getMesh() const { return mesh.getValue(); } + ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; } ArrayRef getSplitAxes() const { return split_axes; } ArrayRef getPartialAxes() const { return partial_axes; } ReductionKind getPartialType() const { return partial_type; } @@ -201,10 +201,12 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh, Type shardType(Type type, MeshOp mesh, MeshSharding sharding); // Insert shard op if there is not one that already has the same sharding. +// Use newShardOp if it is not null. Otherwise create a new one. // May insert resharding if required. +// Potentially updates newShardOp. void maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder); + OpOperand &operand, OpBuilder &builder, + ShardOp &newShardOp); void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, OpBuilder &builder); void maybeInsertSourceShardingAnnotation(MeshSharding sharding, diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 6039e61a93fad..031e6f63bcb42 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -28,7 +28,7 @@ class Mesh_Op traits = []> : Op { } -def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> { +def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> { let summary = "Description of a device/process mesh."; let description = [{ The mesh.mesh operation is a symbol operation that identifies a specific @@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [ "ArrayRef":$split_axes, "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes, "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>, + OpBuilder<(ins "llvm::StringRef":$mesh, + "ArrayRef":$split_axes, + CArg<"ArrayRef", "{}">:$static_halo_sizes, + CArg<"ArrayRef", "{}">:$static_sharded_dims_offsets + )>, OpBuilder<(ins "mlir::mesh::MeshSharding":$from)> ]; let hasVerifier = 1; let hasCanonicalizer = 1; } +def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> { + let summary = "Get the sharding of the given tensor."; + let description = [{ + This operation returns the sharding of the given tensor as a MeshSharding. 
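+
+    Example:
+    ```mlir
+    %sharding = mesh.get_sharding %t : tensor<4x8xf32> -> !mesh.sharding
+    ```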
+ }]; + let arguments = (ins + AnyRankedTensor:$source + ); + let results = (outs + Mesh_Sharding:$result + ); + let assemblyFormat = [{ + $source attr-dict `:` type($source) `->` type($result) + }]; +} + def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> { let summary = "Get the shard shape of a given process/device."; let description = [{ @@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [ (`annotate_for_users` $annotate_for_users^)? attr-dict `:` type($result) }]; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h index b4d25cef05a7b..14aad7f9f6783 100644 --- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h +++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h @@ -36,7 +36,9 @@ struct ShardingOption { bool empty = false; ShardingOption() = default; ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh) - : shardingArray(std::move(shardingArray)), mesh(mesh) {} + : shardingArray(std::move(shardingArray)), mesh(mesh) { + assert(this->mesh); + } static ShardingOption makeEmpty() { auto res = ShardingOption(); res.empty = true; diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 0da82825c8287..33bc89279c08c 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" @@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { arith::registerBufferDeallocationOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferViewFlowOpInterfaceExternalModels(registry); + arith::registerShardingInterfaceExternalModels(registry); arith::registerValueBoundsOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 6149b35befe7d..f96bda603baa6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms ExpandOps.cpp IntRangeOptimizations.cpp ReifyValueBounds.cpp + ShardingInterfaceImpl.cpp UnsignedWhenEquivalent.cpp ADDITIONAL_HEADER_DIRS @@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms MLIRInferIntRangeInterface MLIRIR MLIRMemRefDialect + MLIRMeshDialect MLIRPass + MLIRShardingInterface MLIRTensorDialect MLIRTransforms MLIRTransformUtils diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp new file mode 100644 index 0000000000000..62d137a4cfb0e --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -0,0 +1,105 @@ +//===- ShardingInterfaceImpl.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License 
v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/IR/DialectRegistry.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::mesh; + +namespace { + +// Sharding of arith.constant +// RankedTensor constants can be sharded like any other tensor. +// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> +// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding +// Scalar constants are always replicated and need no sharding annotation. + +struct ConstantShardingInterface + : public ShardingInterface::ExternalModel { + SmallVector getLoopIteratorTypes(Operation *op) const { + auto ndims = 0; + if (auto type = dyn_cast(op->getResult(0).getType())) { + ndims = type.getRank(); + } + return SmallVector(ndims, + utils::IteratorType::parallel); + } + + SmallVector getIndexingMaps(Operation *op) const { + if (auto type = dyn_cast(op->getResult(0).getType())) { + return SmallVector(1, {AffineMap::getMultiDimIdentityMap( + type.getRank(), op->getContext())}); + } + return {}; + } + + // Indicate failure if no result sharding exists. + // Otherwise mirror result sharding if it is a tensor constant. + // Otherwise return replication option. + FailureOr + getShardingOption(Operation *op, ArrayRef operandShardings, + ArrayRef resultShardings) const { + assert(resultShardings.size() == 1 && + "Expecting exactly one result sharding for arith.constant"); + auto resultSharding = resultShardings[0]; + if (!resultSharding) { + return failure(); + } + if (auto type = dyn_cast(op->getResult(0).getType())) { + ShardingArray axesArray(resultSharding.getSplitAxes().size()); + for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) { + axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end()); + } + return ShardingOption(axesArray, resultSharding.getMeshAttr()); + } + return ShardingOption({}, resultSharding.getMeshAttr()); + } + + LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &spmdizationMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + auto cOp = cast(op); + if (auto value = dyn_cast(cOp.getValue())) { + if (!value.isSplat() || !resultShardings[0]) { + // Currently non-splat constants are not supported. + return failure(); + } + auto sharding = resultShardings[0]; + auto newType = cast(shardType( + cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable), + sharding)); + auto newValue = value.resizeSplat(newType); + auto newOp = builder.create(op->getLoc(), newType, newValue); + spmdizationMap.map(op->getResult(0), newOp.getResult()); + spmdizationMap.map(op, newOp.getOperation()); + } else { + // `clone` will populate the mapping of old to new results. 
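+      // Scalar constants are always replicated and carry no sharding
+      // annotation, so the op is cloned verbatim.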
+ (void)builder.clone(*op, spmdizationMap); + } + return success(); + } +}; +} // namespace + +void mlir::arith::registerShardingInterfaceExternalModels( + DialectRegistry ®istry) { + + registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) { + ConstantOp::template attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 33460ff25e9e4..12e1ec6d717ea 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef shardedDimsOffsets = {}, ArrayRef haloSizes = {}) { + // 0d tensors cannot be sharded and must get replicated + if (inShape.empty()) { + assert(outShape.empty()); + return; + } + std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape), llvm::adl_begin(outShape)); @@ -271,7 +277,8 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, - OpBuilder &builder) { + OpBuilder &builder, + ShardOp &newShardOp) { OpBuilder::InsertionGuard insertionGuard(builder); Value operandValue = operand.get(); Operation *operandOp = operand.getOwner(); @@ -279,14 +286,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, ShardOp shardOp = dyn_cast(operandOp); if (shardOp && sharding == shardOp.getSharding() && !shardOp.getAnnotateForUsers()) { - // No need for anything the correct sharding is already set. + // No need for anything if the correct sharding is already set. + if (!newShardOp) { + newShardOp = shardOp; + } return; } - auto shardingOp = builder.create(operandValue.getLoc(), sharding); - auto newShardOp = - builder.create(operandValue.getLoc(), operandValue, shardingOp, - /*annotate_for_users*/ false); + if (!newShardOp) { + auto shardingOp = + builder.create(operandValue.getLoc(), sharding); + newShardOp = + builder.create(operandValue.getLoc(), operandValue, shardingOp, + /*annotate_for_users*/ false); + } IRRewriter rewriter(builder); rewriter.replaceUsesWithIf( operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) { @@ -297,17 +310,19 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, return; } - auto newShardOp2 = - builder.create(operandValue.getLoc(), newShardOp, shardingOp, - /*annotate_for_users*/ true); + auto newShardOp2 = builder.create(operandValue.getLoc(), newShardOp, + newShardOp.getSharding(), + /*annotate_for_users*/ true); rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2); + return; } void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, OpBuilder &builder) { + ShardOp newShardOp; for (auto &use : llvm::make_early_inc_range(result.getUses())) { - maybeInsertTargetShardingAnnotation(sharding, use, builder); + maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp); } } @@ -316,9 +331,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpBuilder &builder) { OpBuilder::InsertionGuard insertionGuard(builder); Value operandValue = operand.get(); - Operation *operandOp = operand.getOwner(); Operation *operandSrcOp = operandValue.getDefiningOp(); bool isBlockArg = !operandSrcOp; + { + auto opType = dyn_cast(operandValue.getType()); + assert(!opType || opType.getRank() > 0 || isFullReplication(sharding)); + } + if 
(!isa(operandValue.getType()) && operandSrcOp && + operandSrcOp->hasTrait()) { + return; + } + + Operation *operandOp = operand.getOwner(); ShardOp shardOp = dyn_cast_or_null(operandSrcOp); if (shardOp && sharding == shardOp.getSharding() && @@ -432,16 +456,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, ArrayRef split_axes, ArrayRef partial_axes, mesh::ReductionKind partial_type, - ArrayRef static_halo_sizes, - ArrayRef static_sharded_dims_offsets) { + ArrayRef static_halos, + ArrayRef static_offsets) { return build( b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes), ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), - static_sharded_dims_offsets), - {}); + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); } void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, @@ -453,6 +475,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, {}, {}, {}, {}); } +void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, + llvm::StringRef mesh, ArrayRef split_axes, + ArrayRef static_halos, + ArrayRef static_offsets) { + return build( + b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh), + MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, + ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum), + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); +} + void ShardingOp::build( ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef split_axes, @@ -579,9 +613,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { namespace { // Sharding annotations "halo sizes" and "sharded dims offsets" // are a mix of attributes and dynamic values. This canonicalization moves -// constant values to the respective attribute lists and so minimizes the number +// constant values to the respective attribute lists, minimizing the number // of values. -class FoldDynamicLists final : public OpRewritePattern { +// It also removes sharded_dims_sizes and halos if they are effectively "empty". +class NormalizeSharding final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -593,18 +628,48 @@ class FoldDynamicLists final : public OpRewritePattern { op.getDynamicShardedDimsOffsets(), b); // No constant operands were folded, just return; - if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) && - failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) { - return failure(); + bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) || + succeeded(foldDynamicIndexList(mixedOffs, true)); + + auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos); + auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs); + + if (dynamicHalos.empty() && !staticHalos.empty()) { + if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) { + staticHalos.clear(); + modified = true; + } + } + + // Remove sharded dims offsets if they are effectively the default values, + // e.g. if they define equi-distance between all neighboring shards. + // Requires static-only offsets. 
Compares the first distance as the + // difference between the first two offsets. Only if all consecutive + // distances are the same, the offsets are removed. + if (dynamicOffs.empty() && !staticOffs.empty()) { + assert(staticOffs.size() >= 2); + auto diff = staticOffs[1] - staticOffs[0]; + bool all_same = staticOffs.size() > 2; + for (auto i = 2u; i < staticOffs.size(); ++i) { + if (staticOffs[i] - staticOffs[i - 1] != diff) { + all_same = false; + break; + } + } + if (all_same) { + staticOffs.clear(); + modified = true; + } } - auto halos = decomposeMixedValues(mixedHalos); - auto offs = decomposeMixedValues(mixedOffs); + if (!modified) { + return failure(); + } - op.setStaticHaloSizes(halos.first); - op.getDynamicHaloSizesMutable().assign(halos.second); - op.setStaticShardedDimsOffsets(offs.first); - op.getDynamicShardedDimsOffsetsMutable().assign(offs.second); + op.setStaticHaloSizes(staticHalos); + op.getDynamicHaloSizesMutable().assign(dynamicHalos); + op.setStaticShardedDimsOffsets(staticOffs); + op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs); return success(); } @@ -613,7 +678,7 @@ class FoldDynamicLists final : public OpRewritePattern { void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// @@ -707,11 +772,19 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const { return !(*this == rhs); } +MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {} + MeshSharding::MeshSharding(Value rhs) { auto shardingOp = mlir::dyn_cast(rhs.getDefiningOp()); assert(shardingOp && "expected sharding op"); - *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(), - shardingOp.getPartialAxes().value_or(ArrayRef()), + auto splitAxes = shardingOp.getSplitAxes().getAxes(); + auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef()); + // If splitAxes and partialAxes are empty, use "empty" constructor. + if (splitAxes.empty() && partialAxes.empty()) { + *this = MeshSharding(shardingOp.getMeshAttr()); + return; + } + *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes, shardingOp.getPartialType().value_or(ReductionKind::Sum), shardingOp.getStaticHaloSizes(), shardingOp.getStaticShardedDimsOffsets(), @@ -727,8 +800,11 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef static_sharded_dims_offsets_, ArrayRef dynamic_halo_sizes_, ArrayRef dynamic_sharded_dims_offsets_) { - MeshSharding res; - res.mesh = mesh_; + MeshSharding res(mesh_); + if (split_axes_.empty() && partial_axes_.empty()) { + return res; + } + res.split_axes.resize(split_axes_.size()); for (auto [i, axis] : llvm::enumerate(split_axes_)) { res.split_axes[i] = @@ -771,6 +847,53 @@ void ShardOp::getAsmResultNames( setNameFn(getResult(), "sharding_annotated"); } +namespace { +// Determine if the given ShardOp is a duplicate of another ShardOp +// on the same value. This can happen if constant values are sharded. +class FoldDuplicateShardOp final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override { + // Get the use-list of the value being sharded and check if it has more than + // one use. 
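+    // A value with a single use cannot feed more than one mesh.shard op.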
+ Value value = op.getSrc(); + if (value.hasOneUse() || value.getDefiningOp()) { + return failure(); + } + + // Iterate through the uses of the value to find a duplicate ShardOp. + for (auto &use : value.getUses()) { + if (use.getOwner() != op.getOperation()) { + auto otherOp = dyn_cast(use.getOwner()); + if (!otherOp || !otherOp->isBeforeInBlock(op)) { + return failure(); + } + // Create a MeshSharding object for the current and the other ShardOp + // If the two are equal replace current op with the other op. + MeshSharding currentSharding(op.getSharding()); + MeshSharding otherSharding(otherOp.getSharding()); + if (currentSharding == otherSharding) { + b.replaceAllUsesWith(op.getResult(), otherOp.getResult()); + b.eraseOp(op.getOperation()); + } else { + // use the other sharding as input for op + op.getSrcMutable().assign(otherOp.getResult()); + } + return success(); + } + } + + return failure(); + } +}; +} // namespace + +void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, + mlir::MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // mesh.process_multi_index op //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index c1f4d563d5b42..f427d004c558f 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -168,17 +168,12 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { // check operands and results type for (Type type : op->getOperandTypes()) - if (!llvm::isa(type)) + if (!llvm::isa(type) && !type.isIntOrIndexOrFloat()) return failure(); for (Type type : op->getResultTypes()) - if (!llvm::isa(type)) + if (!llvm::isa(type) && !type.isIntOrIndexOrFloat()) return failure(); - // check loop types - SmallVector loopTypes = getLoopIteratorTypes(); - if (loopTypes.empty()) - return failure(); - // check maps SmallVector maps = getIndexingMaps(); if (maps.empty()) @@ -286,18 +281,22 @@ mesh::detail::defaultGetShardingOption(Operation *op, continue; AffineMap map = maps[numOperands + shardingIt.index()]; anyShardingInResultsOrOperands = true; - // Handle the split axes: calculate the corresponding loop index for each - // split axes sub-array, and then store the sub-array to - // shardingOption[index] - for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { - AffineExpr expr = std::get<0>(it); - ArrayRef axes = std::get<1>(it).asArrayRef(); - auto dim = cast(expr); - unsigned index = dim.getPosition(); - visitedLoopIndices.insert(index); - if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(), - axes, index))) - return failure(); + if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) { + shardingOption.mesh = shardAttr.getMeshAttr(); + } else { + // Handle the split axes: calculate the corresponding loop index for each + // split axes sub-array, and then store the sub-array to + // shardingOption[index] + for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { + AffineExpr expr = std::get<0>(it); + ArrayRef axes = std::get<1>(it).asArrayRef(); + auto dim = cast(expr); + unsigned index = dim.getPosition(); + visitedLoopIndices.insert(index); + if (failed(fillShardingOption(op, shardingOption, + shardAttr.getMeshAttr(), axes, index))) + return failure(); + } } // Handle the 
partial axes: at this stage, the exact loop index/indices @@ -323,7 +322,7 @@ mesh::detail::defaultGetShardingOption(Operation *op, if (!shardAttr) continue; - anyShardingInResultsOrOperands = true; + anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty(); AffineMap map = maps[shardingIt.index()]; unsigned numDims = map.getNumDims(); @@ -448,7 +447,16 @@ static FailureOr getSharding(OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map) { Value operandValue = opOperand.get(); - auto operandType = cast(operandValue.getType()); + auto operandType = dyn_cast(operandValue.getType()); + if (!operandType) { + if (operandValue.getType().isIntOrIndexOrFloat()) + return MeshSharding(); + return failure(); + } + // 0d tensors cannot be sharded and must get replicated + if (operandType.getRank() == 0) { + return MeshSharding(shardingOption.mesh); + } SmallVector> splitAxes(operandType.getRank()); unsigned numDims = map.getNumDims(); for (auto it : llvm::enumerate(map.getResults())) { @@ -579,7 +587,7 @@ static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding) { if (isa(value.getType())) { - return sharding && isFullReplication(sharding); + return isFullReplication(sharding); } return !sharding; diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp index 4bd3b425219c1..8c989cce63406 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp @@ -282,11 +282,12 @@ static FailureOr selectShardingOption( // a `mesh.shard` operation for all remaining operands and results that do not // have sharding annotations. static LogicalResult visitOp(Operation *op, OpBuilder &builder) { + ShardingInterface shardingOp = llvm::dyn_cast(op); if (op->hasTrait() || - llvm::isa(op)) + (op->hasTrait() && !shardingOp) || + llvm::isa(op)) return success(); - ShardingInterface shardingOp = llvm::dyn_cast(op); if (!shardingOp) { op->emitOpError() << "sharding interface is not implemented."; return failure(); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 327ea0991e4e1..601af0200e785 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -561,7 +561,8 @@ TypedValue reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, TypedValue sourceUnshardedValue, TypedValue sourceShard) { // If source and destination sharding are the same, no need to do anything. 
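+  // The same holds if both are full replications, even when the two sharding
+  // annotations are not syntactically equal.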
- if (sourceSharding == targetSharding) { + if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) && + isFullReplication(targetSharding))) { return sourceShard; } @@ -636,14 +637,6 @@ shardedBlockArgumentTypes(Block &block, return res; } -void spmdizeTriviallyShardableOperation(Operation &op, - ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder); - static LogicalResult spmdizeOperation( Operation &op, ArrayRef spmdizedOperands, ArrayRef operandShardings, @@ -703,8 +696,9 @@ static std::vector getResultShardings(Operation &op) { if (!rankedTensor) { return MeshSharding(); } - - assert(result.hasOneUse()); + if (!result.hasOneUse()) { + return MeshSharding(); + } Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::cast(userOp); return MeshSharding(shardOp.getSharding()); @@ -744,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap, if (isa(op)) { return success(); } + if (auto getShardingOp = dyn_cast(op)) { + auto shardOp = getShardingOp.getSource().getDefiningOp(); + if (!shardOp) { + return op.emitError("expected a shard op as source of get_sharding"); + } + auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp()); + spmdizationMap.map(op.getResult(0), newSharding->getResult(0)); + return success(); + } ShardOp shardOp = llvm::dyn_cast(op); if (shardOp) { @@ -765,6 +768,7 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap, static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { + SmallVector argLocations; llvm::transform(block.getArguments(), std::back_inserter(argLocations), [](BlockArgument arg) { return arg.getLoc(); }); @@ -796,8 +800,12 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, // Snapshot the original blocks to not mess up the iteration when adding new // blocks. 
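+  // Only blocks that contain sharding annotations (mesh.shard ops) need to be
+  // processed.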
SmallVector originalBlocks; - llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks), - [](Block &b) { return &b; }); + for (Block &b : op.getBlocks()) { + if (llvm::any_of(b.getOperations(), + [](Operation &op) { return isa(op); })) { + originalBlocks.push_back(&b); + } + } for (Block *block : originalBlocks) { if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection, @@ -823,10 +831,11 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, break; } } - assert(returnOp); - op.setType(FunctionType::get(op->getContext(), - op.getFunctionBody().front().getArgumentTypes(), - returnOp->getOperandTypes())); + if (returnOp) { + op.setType(FunctionType::get( + op->getContext(), op.getFunctionBody().front().getArgumentTypes(), + returnOp->getOperandTypes())); + } return success(); } diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp index f3e72abe7516e..b2acbf20b3fb9 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp @@ -22,10 +22,11 @@ using namespace mlir::mesh; namespace { -// Sharding of tensor.empty -struct EmptyOpShardingInterface - : public ShardingInterface::ExternalModel { +// Sharding of tensor.empty/tensor.splat +template +struct CreatorOpShardingInterface + : public ShardingInterface::ExternalModel, + OpTy> { SmallVector getLoopIteratorTypes(Operation *op) const { auto ndims = mlir::cast(op->getResult(0).getType()).getRank(); return SmallVector(ndims, @@ -38,7 +39,9 @@ struct EmptyOpShardingInterface auto type = dyn_cast(val.getType()); if (!type) return {}; - return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}; + return SmallVector( + op->getNumOperands() + op->getNumResults(), + {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}); } LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, @@ -82,8 +85,7 @@ struct EmptyOpShardingInterface newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]); } } - newOp = - builder.create(op->getLoc(), shardType, newOperands); + newOp = builder.create(op->getLoc(), shardType, newOperands); spmdizationMap.map(op->getResult(0), newOp->getResult(0)); } else { // `clone` will populate the mapping of old to new results. 
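+      // The result type is unaffected by the sharding, so the op can be
+      // cloned verbatim.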
@@ -100,6 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
     DialectRegistry &registry) {
 
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
+    EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
+        *ctx);
+    SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
+        *ctx);
   });
 }
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir
new file mode 100644
index 0000000000000..6b55dd533a92c
--- /dev/null
+++ b/mlir/test/Dialect/Arith/mesh-spmdize.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
+// RUN:   %s | FileCheck %s
+
+mesh.mesh @mesh4x4(shape = 4x4)
+
+// CHECK-LABEL: func @test_spmdize_constant
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
+// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
+// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+  %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
+  %ci = arith.constant 434 : i32
+  return %sharding_annotated_1 : tensor<1024x1024xf32>
+}
diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir
new file mode 100644
index 0000000000000..19eb340549b0b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
+
+mesh.mesh @mesh4x4(shape = 4x4)
+
+// CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
+// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32>
+func.func @test_shard_constant() ->
(tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + %ci = arith.constant 43.4e+00 : f32 + %o1 = tensor.empty() : tensor<1024x1024xf32> + %res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + return %res : tensor<1024x1024xf32> +} + +// CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} { +// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32 +// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding +// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> +func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %ci = arith.constant 43.4e+00 : f32 + %o1 = tensor.empty() : tensor<1024x1024xf32> + %res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32> + return %sharding_annotated_1 : tensor<1024x1024xf32> +} diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir index f0112d689805d..aff07bbf8a214 100644 --- a/mlir/test/Dialect/Mesh/canonicalization.mlir +++ b/mlir/test/Dialect/Mesh/canonicalization.mlir @@ -207,4 +207,42 @@ func.func @test_shard_offs() -> !mesh.sharding { // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding return %sharding : !mesh.sharding -} \ No newline at end of file +} + +// CHECK-LABEL: func @test_duplicate_shardops +func.func 
@test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding + %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding + %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> + %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> + return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> +} + +// CHECK-LABEL: func @test_duplicate_shardops_diff +func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding + %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding + // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32> + %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32> + %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> + return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> +} diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index 978de4939ee77..43a75bf3d8040 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -164,6 +164,14 @@ func.func @mesh_shard_shape() { return } +// CHECK-LABEL: func @mesh_get_sharding +// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> +func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding { + // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding + %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding + return %0 : 
!mesh.sharding +} + // CHECK-LABEL: func @mesh_shape func.func @mesh_shape() -> (index, index) { // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index c1b96fda0f4a7..59f7162e21013 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -4,6 +4,20 @@ mesh.mesh @mesh_1d(shape = 2) +// CHECK-LABEL: func @return_sharding +func.func @return_sharding( + // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32> + %arg0: tensor<2xf32> +// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) { +) -> (tensor<2xf32>, !mesh.sharding) { + %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding + %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32> + // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding + %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding + // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding + return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding +} + // CHECK-LABEL: func @full_replication func.func @full_replication( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>