Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
adding libs

clang-format

renaming mesh-spmdize.cpp -> mesh-spmdize.mlir and fixing format
  • Loading branch information
fschlimb committed Feb 10, 2025
1 parent 3c76df3 commit 508095a
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 36 deletions.
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRMeshDialect
MLIRPass
MLIRShardingInterface
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using namespace mlir::mesh;

namespace {

// Sharding of arith.empty/arith.splat
// Sharding of arith.constant
struct ConstantShardingInterface
: public ShardingInterface::ExternalModel<ConstantShardingInterface,
ConstantOp> {
Expand Down
9 changes: 6 additions & 3 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
// No need for anything the correct sharding is already set.
// No need for anything if the correct sharding is already set.
return newShardOp ? newShardOp : shardOp;
}

Expand Down Expand Up @@ -639,6 +639,8 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
}
}

// Remove sharded dims offsets if they are effectively the default values,
// e.g. if they define equi-distance between all neighboring shards.
if (offs.second.empty() && !offs.first.empty()) {
assert(offs.first.size() >= 2);
auto diff = offs.first[1] - offs.first[0];
Expand Down Expand Up @@ -772,7 +774,8 @@ MeshSharding::MeshSharding(Value rhs) {
assert(shardingOp && "expected sharding op");
auto splitAxes = shardingOp.getSplitAxes().getAxes();
auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
if(splitAxes.empty() && partialAxes.empty()) {
// If splitAxes and partialAxes are empty, use "empty" constructor.
if (splitAxes.empty() && partialAxes.empty()) {
*this = MeshSharding(shardingOp.getMeshAttr());
return;
}
Expand All @@ -793,7 +796,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
MeshSharding res(mesh_);
if(split_axes_.empty() && partial_axes_.empty()) {
if (split_axes_.empty() && partial_axes_.empty()) {
return res;
}

Expand Down
9 changes: 2 additions & 7 deletions mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,6 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
return failure();

// check loop types
// SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
// if (loopTypes.empty())
// return failure();

// check maps
SmallVector<AffineMap> maps = getIndexingMaps();
if (maps.empty())
Expand Down Expand Up @@ -453,8 +448,8 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
if(!operandType) {
if(operandValue.getType().isIntOrIndexOrFloat())
if (!operandType) {
if (operandValue.getType().isIntOrIndexOrFloat())
return MeshSharding();
return failure();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
std::vector<MeshSharding> res;
res.reserve(op.getNumResults());
llvm::transform(op.getResults(), std::back_inserter(res),
[&op](OpResult result) {
[](OpResult result) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
Expand Down
18 changes: 11 additions & 7 deletions mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ using namespace mlir::mesh;
namespace {

// Sharding of tensor.empty/tensor.splat
template<typename OpTy>
template <typename OpTy>
struct CreatorOpShardingInterface
: public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
: public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
OpTy> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
return SmallVector<utils::IteratorType>(ndims,
Expand All @@ -38,7 +39,9 @@ struct CreatorOpShardingInterface
auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
return SmallVector<AffineMap>(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
return SmallVector<AffineMap>(
op->getNumOperands() + op->getNumResults(),
{AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
}

LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
Expand Down Expand Up @@ -82,8 +85,7 @@ struct CreatorOpShardingInterface
newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
}
}
newOp =
builder.create<OpTy>(op->getLoc(), shardType, newOperands);
newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
spmdizationMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
Expand All @@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {

registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
*ctx);
SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
*ctx);
});
}
17 changes: 0 additions & 17 deletions mlir/test/Dialect/Arith/mesh-spmdize.cpp

This file was deleted.

17 changes: 17 additions & 0 deletions mlir/test/Dialect/Arith/mesh-spmdize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-opt \
// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
// RUN:   %s | FileCheck %s

mesh.mesh @mesh4x4(shape = 4x4)

// Spmdizing a constant sharded on axis 0 of a 4x4 mesh must shrink dim 0
// from 1024 to 256 (1024 / 4). The unsharded i32 constant is untouched.
// NOTE: each CHECK pattern must stay on a single line — FileCheck treats
// a wrapped pattern as two unrelated (and broken) directives.
// CHECK-LABEL: func @test_spmdize_constant
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
  %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
  %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
  %ci = arith.constant 434 : i32
  return %sharding_annotated_1 : tensor<1024x1024xf32>
}

0 comments on commit 508095a

Please sign in to comment.