
[MLIR][mesh] Mesh fixes #124724

Merged
merged 10 commits into from
Feb 12, 2025

Conversation

fschlimb
Contributor

A collection of fixes to the mesh dialect

  • allow constants in sharding propagation/spmdization
  • fix tensor replication (e.g. for 0d tensors)
  • improve canonicalization
  • fix sharding propagation generating redundant ShardOps
  • add new operation mesh.GetShardingOp to enable exchanging sharding information (e.g. across function boundaries)
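The assembly format in the diff below suggests the new op can be used roughly as in this hedged sketch (the mnemonic `get_sharding` and result type `!mesh.sharding` are taken from the tablegen definition in the patch; the mesh, function, and tensor shapes are hypothetical):

```mlir
mesh.mesh @mesh0(shape = 2x2)

// Retrieve the sharding attached to a tensor value, e.g. to forward
// sharding information across a function boundary.
func.func @get_sharding_example(%arg0: tensor<4x8xf32>) -> !mesh.sharding {
  %sharding = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
  return %sharding : !mesh.sharding
}
```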

@yaochengji @AntonLydike

@fschlimb fschlimb requested review from sogartar and mfrancio January 28, 2025 09:56

github-actions bot commented Jan 28, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@fschlimb
Contributor Author

fschlimb commented Feb 3, 2025

ping @sogartar @mfrancio @yaochengji Could you please have a look?

@fschlimb
Contributor Author

ping @sogartar @mfrancio @yaochengji

@llvmbot
Member

llvmbot commented Feb 10, 2025

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

A collection of fixes to the mesh dialect

  • allow constants in sharding propagation/spmdization
  • fix tensor replication (e.g. for 0d tensors)
  • improve canonicalization
  • fix sharding propagation generating redundant ShardOps
  • add new operation mesh.GetShardingOp to enable exchanging sharding information (e.g. across function boundaries)

@yaochengji @AntonLydike


Patch is 48.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124724.diff

17 Files Affected:

  • (added) mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h (+23)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+8-5)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+23-1)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+3-1)
  • (modified) mlir/include/mlir/InitAllDialects.h (+2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt (+3)
  • (added) mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp (+99)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+150-32)
  • (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+30-22)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+26-17)
  • (modified) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+13-8)
  • (added) mlir/test/Dialect/Arith/mesh-spmdize.mlir (+17)
  • (added) mlir/test/Dialect/Arith/sharding-propagation.mlir (+54)
  • (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+39-1)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+10)
  • (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..5addffbe571bee1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace arith {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 75cb096130ca6e4..7de7842baf98abf 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,7 +51,7 @@ class MeshSharding {
   SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
-  MeshSharding() = default;
+  MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
   MeshSharding(Value rhs);
   static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
                           ArrayRef<MeshAxesAttr> split_axes_,
@@ -62,7 +62,7 @@ class MeshSharding {
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
                           ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
-  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+  ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
   ReductionKind getPartialType() const { return partial_type; }
@@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
 Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 
 // Insert shard op if there is not one that already has the same sharding.
+// Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand,
-                                         OpBuilder &builder);
+// Return the target ShardOP (new or existing).
+ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                            OpOperand &operand,
+                                            OpBuilder &builder,
+                                            ShardOp newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 6039e61a93fadc5..031e6f63bcb42cc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
     Op<Mesh_Dialect, mnemonic, traits> {
 }
 
-def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
   let summary = "Description of a device/process mesh.";
   let description = [{
     The mesh.mesh operation is a symbol operation that identifies a specific
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
+    OpBuilder<(ins "llvm::StringRef":$mesh,
+                   "ArrayRef<MeshAxesAttr>":$split_axes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
+    )>,
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
+def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
+  let summary = "Get the sharding of the given tensor.";
+  let description = [{
+    This operation returns the sharding of the given tensor as a MeshSharding.
+  }];
+  let arguments = (ins
+    AnyRankedTensor:$source
+  );
+  let results = (outs
+    Mesh_Sharding:$result
+  );
+  let assemblyFormat = [{
+    $source attr-dict `:` type($source) `->` type($result)
+  }];
+}
+
 def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
   let summary = "Get the shard shape of a given process/device.";
   let description = [{
@@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
       (`annotate_for_users` $annotate_for_users^)?
       attr-dict `:` type($result)
   }];
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index b4d25cef05a7b96..14aad7f9f6783d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -36,7 +36,9 @@ struct ShardingOption {
   bool empty = false;
   ShardingOption() = default;
   ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
-      : shardingArray(std::move(shardingArray)), mesh(mesh) {}
+      : shardingArray(std::move(shardingArray)), mesh(mesh) {
+    assert(this->mesh);
+  }
   static ShardingOption makeEmpty() {
     auto res = ShardingOption();
     res.empty = true;
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0da82825c82878a..33bc89279c08c32 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
   arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+  arith::registerShardingInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
       registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6149b35befe7de2..f96bda603baa63d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   ExpandOps.cpp
   IntRangeOptimizations.cpp
   ReifyValueBounds.cpp
+  ShardingInterfaceImpl.cpp
   UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
   MLIRInferIntRangeInterface
   MLIRIR
   MLIRMemRefDialect
+  MLIRMeshDialect
   MLIRPass
+  MLIRShardingInterface
   MLIRTensorDialect
   MLIRTransforms
   MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..f31db4906775687
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.constant
+struct ConstantShardingInterface
+    : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+                                              ConstantOp> {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto ndims = 0;
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ndims = type.getRank();
+    }
+    return SmallVector<utils::IteratorType>(ndims,
+                                            utils::IteratorType::parallel);
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
+                                           type.getRank(), op->getContext())});
+    }
+    return {};
+  }
+
+  // Indicate failure if no result sharding exists.
+  // Otherwise mirror result sharding if it is a tensor constant.
+  // Otherwise return replication option.
+  FailureOr<ShardingOption>
+  getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+                    ArrayRef<MeshSharding> resultShardings) const {
+    if (!resultShardings[0]) {
+      return failure();
+    }
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
+      for (auto [i, axes] :
+           llvm::enumerate(resultShardings[0].getSplitAxes())) {
+        axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
+      }
+      return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
+    }
+    return ShardingOption({}, resultShardings[0].getMeshAttr());
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    auto cOp = cast<ConstantOp>(op);
+    auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
+    if (value) {
+      if (!value.isSplat() || !resultShardings[0]) {
+        // Currently non-splat constants are not supported.
+        return failure();
+      }
+      auto sharding = resultShardings[0];
+      auto newType = cast<RankedTensorType>(shardType(
+          cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+          sharding));
+      auto newValue = value.resizeSplat(newType);
+      auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
+      spmdizationMap.map(op->getResult(0), newOp.getResult());
+      spmdizationMap.map(op, newOp.getOperation());
+    } else {
+      // `clone` will populate the mapping of old to new results.
+      (void)builder.clone(*op, spmdizationMap);
+    }
+    return success();
+  }
+};
+} // namespace
+
+void mlir::arith::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+    ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 33460ff25e9e45d..c789fc527e3f680 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
                        ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
+  // 0d tensors cannot be sharded and must get replicated
+  if (inShape.empty()) {
+    assert(outShape.empty());
+    return;
+  }
+
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
@@ -269,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder) {
+ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                                        OpOperand &operand,
+                                                        OpBuilder &builder,
+                                                        ShardOp newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Value operandValue = operand.get();
   Operation *operandOp = operand.getOwner();
@@ -279,14 +286,17 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
       !shardOp.getAnnotateForUsers()) {
-    // No need for anything the correct sharding is already set.
-    return;
+    // No need for anything if the correct sharding is already set.
+    return newShardOp ? newShardOp : shardOp;
   }
 
-  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
-  auto newShardOp =
-      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
-                              /*annotate_for_users*/ false);
+  if (!newShardOp) {
+    auto shardingOp =
+        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+    newShardOp =
+        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+                                /*annotate_for_users*/ false);
+  }
   IRRewriter rewriter(builder);
   rewriter.replaceUsesWithIf(
       operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -294,20 +304,23 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
       });
 
   if (!shardOp || shardOp.getAnnotateForUsers()) {
-    return;
+    return newShardOp;
   }
 
-  auto newShardOp2 =
-      builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
-                              /*annotate_for_users*/ true);
+  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+                                             newShardOp.getSharding(),
+                                             /*annotate_for_users*/ true);
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  return newShardOp;
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
+  ShardOp newShardOp;
   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder);
+    newShardOp =
+        maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
   }
 }
 
@@ -316,9 +329,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                      OpBuilder &builder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Value operandValue = operand.get();
-  Operation *operandOp = operand.getOwner();
   Operation *operandSrcOp = operandValue.getDefiningOp();
   bool isBlockArg = !operandSrcOp;
+  {
+    auto opType = dyn_cast<mlir::RankedTensorType>(operandValue.getType());
+    assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
+  }
+  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
+      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
+    return;
+  }
+
+  Operation *operandOp = operand.getOwner();
   ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
 
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -432,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        ArrayRef<MeshAxesAttr> split_axes,
                        ArrayRef<MeshAxis> partial_axes,
                        mesh::ReductionKind partial_type,
-                       ArrayRef<int64_t> static_halo_sizes,
-                       ArrayRef<int64_t> static_sharded_dims_offsets) {
+                       ArrayRef<int64_t> static_halos,
+                       ArrayRef<int64_t> static_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(),
-                                     static_sharded_dims_offsets),
-      {});
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
 }
 
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -453,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
       {}, {}, {}, {});
 }
 
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+                       llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+                       ArrayRef<int64_t> static_halos,
+                       ArrayRef<int64_t> static_offsets) {
+  return build(
+      b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
+      MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
+      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+}
+
 void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
@@ -579,9 +611,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 namespace {
 // Sharding annotations "halo sizes" and "sharded dims offsets"
 // are a mix of attributes and dynamic values. This canonicalization moves
-// constant values to the respective attribute lists and so minimizes the number
+// constant values to the respective attribute lists, minimizing the number
 // of values.
-class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+// It also removes sharded_dims_sizes and halos if they are effectively "empty".
+class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
 public:
   using OpRewritePattern<ShardingOp>::OpRewritePattern;
 
@@ -593,14 +626,41 @@ class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
                                     op.getDynamicShardedDimsOffsets(), b);
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedOffs, /*onlyNo...
[truncated]
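The `NormalizeSharding` pattern above (renamed from `FoldDynamicLists`) folds constant dynamic values into the static attribute lists. A hedged before/after sketch of the intended effect (mesh name, axes, and sizes are hypothetical; syntax follows the mesh dialect's sharding assembly format):

```mlir
// Before canonicalization: one halo size is a dynamic SSA value.
%c4 = arith.constant 4 : i64
%s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [2, %c4] : !mesh.sharding

// After canonicalization: the constant is moved into the static list,
// minimizing the number of dynamic operands.
%s2 = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [2, 4] : !mesh.sharding
```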

Member

@rengolin rengolin left a comment

Comments on the first patch.

@fschlimb
Contributor Author

Thanks @rengolin for your thorough review and comments. I made the modifications as you suggested.
Let me know if I can provide more clarification and/or changes.

@rengolin
Member

> Thanks @rengolin for your thorough review and comments. I made the modifications as you suggested. Let me know if I can provide more clarification and/or changes.

Thanks! So far so good. Just reviewed the final commits and left some comments. Should be good with those addressed.

Member

@rengolin rengolin left a comment

Thank you Frank, this looks really good!

@fschlimb fschlimb merged commit 0fd50ec into llvm:main Feb 12, 2025
8 checks passed
@llvm-ci
Collaborator

llvm-ci commented Feb 12, 2025

LLVM Buildbot has detected a new failure on builder premerge-monolithic-linux running on premerge-linux-1 while building mlir at step 7 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/22645

Here is the relevant piece of the build log for the reference
Step 7 (test-build-unified-tree-check-all) failure: test (failure)
...
UNSUPPORTED: UBSan-Standalone-lld-x86_64 :: TestCases/Misc/no-interception.cpp (94250 of 98271)
PASS: UBSan-MemorySanitizer-x86_64 :: TestCases/Pointer/align-assume-attribute-assume_aligned-on-function-two-params.cpp (94251 of 98271)
PASS: UBSan-Standalone-lld-x86_64 :: TestCases/Misc/enum.cpp (94252 of 98271)
PASS: UBSan-Standalone-lld-x86_64 :: TestCases/Misc/abs.cpp (94253 of 98271)
UNSUPPORTED: UBSan-Standalone-lld-x86_64 :: TestCases/Misc/objc-cast.m (94254 of 98271)
PASS: UBSan-MemorySanitizer-x86_64 :: TestCases/ImplicitConversion/bitfield-conversion.c (94255 of 98271)
PASS: UBSan-Standalone-lld-x86_64 :: TestCases/Misc/bounds.cpp (94256 of 98271)
PASS: UBSan-Standalone-lld-x86_64 :: TestCases/Integer/uincdec-overflow.cpp (94257 of 98271)
PASS: UBSan-Standalone-lld-x86_64 :: TestCases/ImplicitConversion/integer-sign-change-ignorelist.c (94258 of 98271)
TIMEOUT: MLIR :: Examples/standalone/test.toy (94259 of 98271)
******************** TEST 'MLIR :: Examples/standalone/test.toy' FAILED ********************
Exit Code: 1
Timeout: Reached timeout of 60 seconds

Command Output (stdout):
--
# RUN: at line 1
"/etc/cmake/bin/cmake" "/build/buildbot/premerge-monolithic-linux/llvm-project/mlir/examples/standalone" -G "Ninja"  -DCMAKE_CXX_COMPILER=/usr/bin/clang++  -DCMAKE_C_COMPILER=/usr/bin/clang   -DLLVM_ENABLE_LIBCXX=OFF -DMLIR_DIR=/build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir  -DLLVM_USE_LINKER=lld  -DPython3_EXECUTABLE="/usr/bin/python3.10"
# executed command: /etc/cmake/bin/cmake /build/buildbot/premerge-monolithic-linux/llvm-project/mlir/examples/standalone -G Ninja -DCMAKE_CXX_COMPILER=/usr/bin/clang++ -DCMAKE_C_COMPILER=/usr/bin/clang -DLLVM_ENABLE_LIBCXX=OFF -DMLIR_DIR=/build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir -DLLVM_USE_LINKER=lld -DPython3_EXECUTABLE=/usr/bin/python3.10
# .---command stdout------------
# | -- The CXX compiler identification is Clang 16.0.6
# | -- The C compiler identification is Clang 16.0.6
# | -- Detecting CXX compiler ABI info
# | -- Detecting CXX compiler ABI info - done
# | -- Check for working CXX compiler: /usr/bin/clang++ - skipped
# | -- Detecting CXX compile features
# | -- Detecting CXX compile features - done
# | -- Detecting C compiler ABI info
# | -- Detecting C compiler ABI info - done
# | -- Check for working C compiler: /usr/bin/clang - skipped
# | -- Detecting C compile features
# | -- Detecting C compile features - done
# | -- Looking for histedit.h
# | -- Looking for histedit.h - found
# | -- Found LibEdit: /usr/include (found version "2.11") 
# | -- Found ZLIB: /usr/lib/x86_64-linux-gnu/libz.so (found version "1.2.11") 
# | -- Found LibXml2: /usr/lib/x86_64-linux-gnu/libxml2.so (found version "2.9.13") 
# | -- Using MLIRConfig.cmake in: /build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir
# | -- Using LLVMConfig.cmake in: /build/buildbot/premerge-monolithic-linux/build/lib/cmake/llvm
# | -- Linker detection: unknown
# | -- Performing Test LLVM_LIBSTDCXX_MIN
# | -- Performing Test LLVM_LIBSTDCXX_MIN - Success
# | -- Performing Test LLVM_LIBSTDCXX_SOFT_ERROR
# | -- Performing Test LLVM_LIBSTDCXX_SOFT_ERROR - Success
# | -- Performing Test CXX_SUPPORTS_CUSTOM_LINKER
# | -- Performing Test CXX_SUPPORTS_CUSTOM_LINKER - Success
# | -- Performing Test C_SUPPORTS_FPIC
# | -- Performing Test C_SUPPORTS_FPIC - Success
# | -- Performing Test CXX_SUPPORTS_FPIC

flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
A collection of fixes to the mesh dialect
- allow constants in sharding propagation/spmdization
- fixes to tensor replication (e.g. 0d tensors)
- improved canonicalization
- sharding propagation incorrectly generated too many ShardOps
New operation `mesh.GetShardOp` enables exchanging sharding information
(like on function boundaries)
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025