Skip to content

Commit

Permalink
comments and nicer code (from review)
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlimb committed Feb 12, 2025
1 parent 5702b45 commit 405e4b0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
35 changes: 19 additions & 16 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,30 +629,33 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
succeeded(foldDynamicIndexList(mixedOffs, true));

auto halos = decomposeMixedValues(mixedHalos);
auto offs = decomposeMixedValues(mixedOffs);
auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);

if (halos.second.empty() && !halos.first.empty()) {
if (halos.first[0] == 0 && llvm::all_equal(halos.first)) {
halos.first.clear();
if (dynamicHalos.empty() && !staticHalos.empty()) {
if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
staticHalos.clear();
modified = true;
}
}

// Remove sharded dims offsets if they are effectively the default values,
// e.g. if they define equi-distance between all neighboring shards.
if (offs.second.empty() && !offs.first.empty()) {
assert(offs.first.size() >= 2);
auto diff = offs.first[1] - offs.first[0];
bool all_same = offs.first.size() > 2;
for (auto i = 2u; i < offs.first.size(); ++i) {
if (offs.first[i] - offs.first[i - 1] != diff) {
// Requires static-only offsets. Compares the first distance as the
// difference between the first two offsets. Only if all consecutive
// distances are the same, the offsets are removed.
if (dynamicOffs.empty() && !staticOffs.empty()) {
assert(staticOffs.size() >= 2);
auto diff = staticOffs[1] - staticOffs[0];
bool all_same = staticOffs.size() > 2;
for (auto i = 2u; i < staticOffs.size(); ++i) {
if (staticOffs[i] - staticOffs[i - 1] != diff) {
all_same = false;
break;
}
}
if (all_same) {
offs.first.clear();
staticOffs.clear();
modified = true;
}
}
Expand All @@ -661,10 +664,10 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
return failure();
}

op.setStaticHaloSizes(halos.first);
op.getDynamicHaloSizesMutable().assign(halos.second);
op.setStaticShardedDimsOffsets(offs.first);
op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
op.setStaticHaloSizes(staticHalos);
op.getDynamicHaloSizesMutable().assign(dynamicHalos);
op.setStaticShardedDimsOffsets(staticOffs);
op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);

return success();
}
Expand Down
2 changes: 0 additions & 2 deletions mlir/test/Dialect/Mesh/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ func.func @mesh_shard_shape() {
// CHECK-LABEL: func @mesh_get_sharding
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
%s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
// CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
%0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
return %0 : !mesh.sharding
Expand Down

0 comments on commit 405e4b0

Please sign in to comment.