From 508095ac58bd85dfd6a0cdc7bad93ef57fbc8610 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 28 Jan 2025 10:39:46 +0100 Subject: [PATCH] comments adding libs clang-format renaming mesh-spmdize.cpp -> mesh-spmdize.mlir and fixing format --- .../Dialect/Arith/Transforms/CMakeLists.txt | 2 ++ .../Arith/Transforms/ShardingInterfaceImpl.cpp | 2 +- mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 9 ++++++--- .../Mesh/Interfaces/ShardingInterface.cpp | 9 ++------- .../Dialect/Mesh/Transforms/Spmdization.cpp | 2 +- .../Extensions/MeshShardingExtensions.cpp | 18 +++++++++++------- mlir/test/Dialect/Arith/mesh-spmdize.cpp | 17 ----------------- mlir/test/Dialect/Arith/mesh-spmdize.mlir | 17 +++++++++++++++++ 8 files changed, 40 insertions(+), 36 deletions(-) delete mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.cpp create mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.mlir diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 30dd84aff120f..f96bda603baa6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -27,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms MLIRInferIntRangeInterface MLIRIR MLIRMemRefDialect + MLIRMeshDialect MLIRPass + MLIRShardingInterface MLIRTensorDialect MLIRTransforms MLIRTransformUtils diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp index fc033294eb01b..f31db49067756 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -19,7 +19,7 @@ using namespace mlir::mesh; namespace { -// Sharding of arith.empty/arith.splat +// Sharding of arith.constant struct ConstantShardingInterface : public ShardingInterface::ExternalModel { diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index f84d467048522..c789fc527e3f6 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -286,7 +286,7 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, ShardOp shardOp = dyn_cast(operandOp); if (shardOp && sharding == shardOp.getSharding() && !shardOp.getAnnotateForUsers()) { - // No need for anything the correct sharding is already set. + // No need for anything if the correct sharding is already set. return newShardOp ? newShardOp : shardOp; } @@ -639,6 +639,8 @@ class NormalizeSharding final : public OpRewritePattern { } } + // Remove sharded dims offsets if they are effectively the default values, + // e.g. if they define equi-distance between all neighboring shards. if (offs.second.empty() && !offs.first.empty()) { assert(offs.first.size() >= 2); auto diff = offs.first[1] - offs.first[0]; @@ -772,7 +774,8 @@ MeshSharding::MeshSharding(Value rhs) { assert(shardingOp && "expected sharding op"); auto splitAxes = shardingOp.getSplitAxes().getAxes(); auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef()); - if(splitAxes.empty() && partialAxes.empty()) { + // If splitAxes and partialAxes are empty, use "empty" constructor. + if (splitAxes.empty() && partialAxes.empty()) { *this = MeshSharding(shardingOp.getMeshAttr()); return; } @@ -793,7 +796,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef dynamic_halo_sizes_, ArrayRef dynamic_sharded_dims_offsets_) { MeshSharding res(mesh_); - if(split_axes_.empty() && partial_axes_.empty()) { + if (split_axes_.empty() && partial_axes_.empty()) { return res; } diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index aaffe759b0cef..f427d004c558f 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -174,11 +174,6 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { if (!llvm::isa(type) && !type.isIntOrIndexOrFloat()) return failure(); - // check loop types - // SmallVector loopTypes = getLoopIteratorTypes(); - // if (loopTypes.empty()) - // return failure(); - // check maps SmallVector maps = getIndexingMaps(); if (maps.empty()) @@ -453,8 +448,8 @@ static FailureOr getSharding(OpOperand &opOperand, AffineMap map) { Value operandValue = opOperand.get(); auto operandType = dyn_cast(operandValue.getType()); - if(!operandType) { - if(operandValue.getType().isIntOrIndexOrFloat()) + if (!operandType) { + if (operandValue.getType().isIntOrIndexOrFloat()) return MeshSharding(); return failure(); } diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 4ec8bbc0dff7d..601af0200e785 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -690,7 +690,7 @@ static std::vector getResultShardings(Operation &op) { std::vector res; res.reserve(op.getNumResults()); llvm::transform(op.getResults(), std::back_inserter(res), - [&op](OpResult result) { + [](OpResult result) { TypedValue rankedTensor = dyn_cast>(result); if (!rankedTensor) { diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp index 6bb5d4a66f39e..b2acbf20b3fb9 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp @@ -23,9 +23,10 @@ using namespace mlir::mesh; namespace { // Sharding of tensor.empty/tensor.splat -template +template struct CreatorOpShardingInterface - : public ShardingInterface::ExternalModel, OpTy> { + : public ShardingInterface::ExternalModel, + OpTy> { SmallVector getLoopIteratorTypes(Operation *op) const { auto ndims = mlir::cast(op->getResult(0).getType()).getRank(); return SmallVector(ndims, @@ -38,7 +39,9 @@ struct CreatorOpShardingInterface auto type = dyn_cast(val.getType()); if (!type) return {}; - return SmallVector(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}); + return SmallVector( + op->getNumOperands() + op->getNumResults(), + {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}); } LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, @@ -82,8 +85,7 @@ struct CreatorOpShardingInterface newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]); } } - newOp = - builder.create(op->getLoc(), shardType, newOperands); + newOp = builder.create(op->getLoc(), shardType, newOperands); spmdizationMap.map(op->getResult(0), newOp->getResult(0)); } else { // `clone` will populate the mapping of old to new results. @@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { - EmptyOp::template attachInterface>(*ctx); - SplatOp::template attachInterface>(*ctx); + EmptyOp::template attachInterface>( + *ctx); + SplatOp::template attachInterface>( + *ctx); }); } diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp deleted file mode 100644 index 0688e14b1cf72..0000000000000 --- a/mlir/test/Dialect/Arith/mesh-spmdize.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh4x4(shape = 4x4) - -// CHECK-LABEL: func @test_spmdize_constant -// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32> -// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32 -// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> -func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { - %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32> - %ci = arith.constant 434 : i32 - return %sharding_annotated_1 : tensor<1024x1024xf32> -} diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir new file mode 100644 index 0000000000000..6b55dd533a92c --- /dev/null +++ b/mlir/test/Dialect/Arith/mesh-spmdize.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \ +// RUN: %s | FileCheck %s + +mesh.mesh @mesh4x4(shape = 4x4) + +// CHECK-LABEL: func @test_spmdize_constant +// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : +// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : +// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> +func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding + %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32> + %ci = arith.constant 434 : i32 + return %sharding_annotated_1 : tensor<1024x1024xf32> +}