[MLIR][mesh] Mesh fixes (llvm#124724)
A collection of fixes to the mesh dialect:
- allow constants in sharding propagation/spmdization
- fix tensor replication (e.g. for 0d tensors)
- improve canonicalization
- fix sharding propagation generating redundant ShardOps

A new operation `mesh.get_sharding` (`GetShardingOp`) enables exchanging sharding information,
for example across function boundaries.
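For illustration, a minimal sketch (not part of the patch) of how the new op can be used to query and pass along the sharding of a function argument; the mesh name `@mesh0`, the function name, and the tensor shape are hypothetical:

```mlir
mesh.mesh @mesh0(shape = 4)

// Query the sharding annotated on the argument (e.g. by a caller) and
// return it so it can be reused elsewhere.
func.func @get_arg_sharding(%arg0: tensor<1024x1024xf32>) -> !mesh.sharding {
  %s = mesh.get_sharding %arg0 : tensor<1024x1024xf32> -> !mesh.sharding
  return %s : !mesh.sharding
}
```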
fschlimb authored and flovent committed Feb 13, 2025
1 parent 0994ad7 commit 8d93683
Showing 17 changed files with 525 additions and 89 deletions.
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_

namespace mlir {

class DialectRegistry;

namespace arith {

void registerShardingInterfaceExternalModels(DialectRegistry &registry);

} // namespace arith
} // namespace mlir

#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,7 +51,7 @@ class MeshSharding {
SmallVector<Value> dynamic_sharded_dims_offsets;

public:
MeshSharding() = default;
MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
@@ -62,7 +62,7 @@ class MeshSharding {
ArrayRef<Value> dynamic_halo_sizes_ = {},
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh.getValue(); }
::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
@@ -201,10 +201,12 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);

// Insert shard op if there is not one that already has the same sharding.
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Potentially updates newShardOp.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
OpOperand &operand, OpBuilder &builder,
ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
24 changes: 23 additions & 1 deletion mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}

def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
let summary = "Description of a device/process mesh.";
let description = [{
The mesh.mesh operation is a symbol operation that identifies a specific
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
OpBuilder<(ins "llvm::StringRef":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
)>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
let summary = "Get the sharding of the given tensor.";
let description = [{
This operation returns the sharding of the given tensor as a MeshSharding.
}];
let arguments = (ins
AnyRankedTensor:$source
);
let results = (outs
Mesh_Sharding:$result
);
let assemblyFormat = [{
$source attr-dict `:` type($source) `->` type($result)
}];
}
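Per the assembly format above, the new op would print roughly as follows (a sketch; the SSA names and the tensor shape are made up):

```mlir
%sharding = mesh.get_sharding %src : tensor<4x8xf32> -> !mesh.sharding
```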

def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
let description = [{
@@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
}];
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
@@ -36,7 +36,9 @@ struct ShardingOption {
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
: shardingArray(std::move(shardingArray)), mesh(mesh) {}
: shardingArray(std::move(shardingArray)), mesh(mesh) {
assert(this->mesh);
}
static ShardingOption makeEmpty() {
auto res = ShardingOption();
res.empty = true;
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
arith::registerShardingInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
ExpandOps.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
ShardingInterfaceImpl.cpp
UnsignedWhenEquivalent.cpp

ADDITIONAL_HEADER_DIRS
@@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRMeshDialect
MLIRPass
MLIRShardingInterface
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils
105 changes: 105 additions & 0 deletions mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -0,0 +1,105 @@
//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::mesh;

namespace {

// Sharding of arith.constant
// RankedTensor constants can be sharded like any other tensor.
// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
// Scalar constants are always replicated and need no sharding annotation.

struct ConstantShardingInterface
: public ShardingInterface::ExternalModel<ConstantShardingInterface,
ConstantOp> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = 0;
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
ndims = type.getRank();
}
return SmallVector<utils::IteratorType>(ndims,
utils::IteratorType::parallel);
}

SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
type.getRank(), op->getContext())});
}
return {};
}

// Indicate failure if no result sharding exists.
// Otherwise mirror result sharding if it is a tensor constant.
// Otherwise return replication option.
FailureOr<ShardingOption>
getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings) const {
assert(resultShardings.size() == 1 &&
"Expecting exactly one result sharding for arith.constant");
auto resultSharding = resultShardings[0];
if (!resultSharding) {
return failure();
}
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
ShardingArray axesArray(resultSharding.getSplitAxes().size());
for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
}
return ShardingOption(axesArray, resultSharding.getMeshAttr());
}
return ShardingOption({}, resultSharding.getMeshAttr());
}

LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
auto cOp = cast<ConstantOp>(op);
if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
if (!value.isSplat() || !resultShardings[0]) {
// Currently non-splat constants are not supported.
return failure();
}
auto sharding = resultShardings[0];
auto newType = cast<RankedTensorType>(shardType(
cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
sharding));
auto newValue = value.resizeSplat(newType);
auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
spmdizationMap.map(op->getResult(0), newOp.getResult());
spmdizationMap.map(op, newOp.getOperation());
} else {
// `clone` will populate the mapping of old to new results.
(void)builder.clone(*op, spmdizationMap);
}
return success();
}
};
} // namespace

void mlir::arith::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {

registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
});
}
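As a hedged illustration of the `spmdize` path above for a splat constant: assuming a hypothetical mesh `@mesh0` of shape 4 and a result sharding with `split_axes = [[0]]`, the splat is simply resized to the local shard type rather than materializing the full tensor on every device:

```mlir
// Annotated input (before spmdization); names and shapes are hypothetical.
mesh.mesh @mesh0(shape = 4)
%sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%annotated = mesh.shard %cst to %sharding : tensor<1024x1024xf32>

// After spmdization each device holds only its shard of the splat:
%cst_shard = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
```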