Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][mesh] Mesh fixes #124724

Merged
merged 10 commits into from
Feb 12, 2025
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_

namespace mlir {

class DialectRegistry;

namespace arith {

void registerShardingInterfaceExternalModels(DialectRegistry &registry);

} // namespace arith
} // namespace mlir

#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class MeshSharding {
SmallVector<Value> dynamic_sharded_dims_offsets;

public:
MeshSharding() = default;
MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
Expand All @@ -62,7 +62,7 @@ class MeshSharding {
ArrayRef<Value> dynamic_halo_sizes_ = {},
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh.getValue(); }
::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
Expand Down Expand Up @@ -201,10 +201,12 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);

// Insert shard op if there is not one that already has the same sharding.
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Potentially updates newShardOp.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
OpOperand &operand, OpBuilder &builder,
ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
Expand Down
24 changes: 23 additions & 1 deletion mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}

def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
let summary = "Description of a device/process mesh.";
let description = [{
The mesh.mesh operation is a symbol operation that identifies a specific
Expand Down Expand Up @@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
OpBuilder<(ins "llvm::StringRef":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
)>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
let summary = "Get the sharding of the given tensor.";
let description = [{
This operation returns the sharding of the given tensor as a MeshSharding.
}];
let arguments = (ins
AnyRankedTensor:$source
);
let results = (outs
Mesh_Sharding:$result
);
let assemblyFormat = [{
$source attr-dict `:` type($source) `->` type($result)
}];
}

def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
let description = [{
Expand Down Expand Up @@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
}];
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ struct ShardingOption {
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
: shardingArray(std::move(shardingArray)), mesh(mesh) {}
: shardingArray(std::move(shardingArray)), mesh(mesh) {
assert(this->mesh);
}
static ShardingOption makeEmpty() {
auto res = ShardingOption();
res.empty = true;
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
Expand Down Expand Up @@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
arith::registerShardingInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
ExpandOps.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
ShardingInterfaceImpl.cpp
UnsignedWhenEquivalent.cpp

ADDITIONAL_HEADER_DIRS
Expand All @@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRMeshDialect
MLIRPass
MLIRShardingInterface
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils
Expand Down
105 changes: 105 additions & 0 deletions mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::mesh;

namespace {

// Sharding of arith.constant
// RankedTensor constants can be sharded like any other tensor.
// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
// Scalar constants are always replicated and need no sharding annotation.

struct ConstantShardingInterface
: public ShardingInterface::ExternalModel<ConstantShardingInterface,
ConstantOp> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = 0;
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
ndims = type.getRank();
}
return SmallVector<utils::IteratorType>(ndims,
utils::IteratorType::parallel);
}

SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
type.getRank(), op->getContext())});
}
return {};
}

// Indicate failure if no result sharding exists.
// Otherwise mirror result sharding if it is a tensor constant.
// Otherwise return replication option.
FailureOr<ShardingOption>
getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings) const {
assert(resultShardings.size() == 1 &&
"Expecting exactly one result sharding for arith.constant");
auto resultSharding = resultShardings[0];
if (!resultSharding) {
return failure();
}
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
ShardingArray axesArray(resultSharding.getSplitAxes().size());
for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
}
return ShardingOption(axesArray, resultSharding.getMeshAttr());
}
return ShardingOption({}, resultSharding.getMeshAttr());
}

LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
auto cOp = cast<ConstantOp>(op);
if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
if (!value.isSplat() || !resultShardings[0]) {
// Currently non-splat constants are not supported.
return failure();
}
auto sharding = resultShardings[0];
auto newType = cast<RankedTensorType>(shardType(
cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
sharding));
auto newValue = value.resizeSplat(newType);
auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
spmdizationMap.map(op->getResult(0), newOp.getResult());
spmdizationMap.map(op, newOp.getOperation());
} else {
// `clone` will populate the mapping of old to new results.
(void)builder.clone(*op, spmdizationMap);
}
return success();
}
};
} // namespace

void mlir::arith::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {

registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
});
}
Loading