Issues enabling ND mesh resharding in Spmdization pass: incorrect axis comparison and resharding assertion failures #136117

Open
zhangdianchen opened this issue Apr 17, 2025 · 1 comment

@zhangdianchen

When trying to enable ND mesh resharding in the Spmdization pass of MLIR, I encountered several issues that cause incorrect behavior or assertion failures. Below is a detailed breakdown:

1. Incorrect detection logic in detectMoveLastSplitAxisInResharding

Problem Reproduction:
When executing the following resharding sequence:

%sharding1 = mesh.sharding @mesh_3d split_axes = [[0, 1], [2]] : !mesh.sharding
%in1_sharded1 = mesh.shard %in1 to %sharding1 : tensor<8x16xi8>
%sharding2 = mesh.sharding @mesh_3d split_axes = [[0], [1, 2]] : !mesh.sharding
%in1_sharded2 = mesh.shard %in1_sharded1 to %sharding2 annotate_for_users : tensor<8x16xi8>

The pass is expected to detect a valid last-axis movement and insert a mesh.all_to_all operation. Instead, it crashes with the following assertion:

mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `targetShard && "Did not find any pattern to apply."' failed.

Root Cause:
In detectMoveLastSplitAxisInResharding, the logic checks:

if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
    targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
    sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
        targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().back())
  continue;

In the example [[0, 1], [2]] -> [[0], [1, 2]], this compares:

  • source.split_axes[0][1] = 1 vs. target.split_axes[1][1] = 2, which is incorrect.

It should instead compare source.back() with target.front():

sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
    targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().front()

Additional incorrect check:

if (!llvm::equal(
      llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().begin(),
                       sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().end() - 1),
      llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().begin(),
                       targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().end() - 1)))
  continue;

This compares the wrong slices. In the example, it ends up comparing [0] and [1]. Instead, it should skip the first element of the target range and compare:

if (llvm::equal(
      llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().begin(),
                       sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().end() - 1),
      llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().begin() + 1,
                       targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().end())))
  continue;

This now compares [0] with [2] — which is correct.
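
The first comparison discussed in this section can be reproduced without MLIR. The following is a minimal sketch, assuming each split_axes entry is modeled as a plain std::vector<int>. The names AxisList, lastAxisMatches, compareTargetFront and checkIssueExample are hypothetical and exist only for illustration; they are not part of the pass.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for one split_axes entry (the grid axes that split a
// single tensor axis). The real pass uses MLIR attribute types instead.
using AxisList = std::vector<int>;

// Returns true when the "last split axis" comparison lets the pair through.
// compareTargetFront == false: source.back() vs. target.back() (as quoted).
// compareTargetFront == true:  source.back() vs. target.front() (as proposed).
bool lastAxisMatches(const AxisList &source, const AxisList &target,
                     bool compareTargetFront) {
  if (source.empty() || target.empty())
    return false;
  const int targetAxis = compareTargetFront ? target.front() : target.back();
  return source.back() == targetAxis;
}

// Applies both variants to the example [[0, 1], [2]] -> [[0], [1, 2]].
void checkIssueExample() {
  const AxisList source = {0, 1}; // source split_axes[0]
  const AxisList target = {1, 2}; // target split_axes[1]
  assert(!lastAxisMatches(source, target, false)); // compares 1 with 2
  assert(lastAxisMatches(source, target, true));   // compares 1 with 1
}
```

Only the first of the two checks is modeled here; the slice comparison that follows it is left out of the sketch.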

2. Incorrect ShardingTarget construction in targetShardingInMoveLastAxis

Getting past the above assertion leads to another failure:

mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `actualTargetSharding == targetSharding' failed.

Root cause: in targetShardingInMoveLastAxis, the targetShardingSplitAxes are incorrectly ordered.

Current result:

actualTargetSharding: split_axes = [[0], [2, 1]]
targetSharding:       split_axes = [[0], [1, 2]]

Fix: Instead of:
targetSplitAxes.push_back(gridAxis);
Use:

targetSplitAxes.insert(targetSplitAxes.begin(), gridAxis);
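
The ordering difference can be checked in isolation. This is a small sketch, assuming the target tensor axis starts out with the source's [2] and that grid axis 1 is the one being moved, as in the example above. AxisList, addMovedAxis, insertAtFront and checkMovedAxisOrder are hypothetical names, not code from the pass.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for one split_axes entry.
using AxisList = std::vector<int>;

// Adds gridAxis to the split axes of the target tensor axis, either at the
// back (the behavior quoted above) or at the front (the proposed change).
AxisList addMovedAxis(AxisList targetSplitAxes, int gridAxis,
                      bool insertAtFront) {
  if (insertAtFront)
    targetSplitAxes.insert(targetSplitAxes.begin(), gridAxis);
  else
    targetSplitAxes.push_back(gridAxis);
  return targetSplitAxes;
}

// Reproduces the two orderings reported in this issue.
void checkMovedAxisOrder() {
  const AxisList before = {2}; // split axes of tensor axis 1 in the source
  assert((addMovedAxis(before, 1, false) == AxisList{2, 1})); // actual result
  assert((addMovedAxis(before, 1, true) == AxisList{1, 2}));  // expected target
}
```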

3. Bug in handlePartialAxesDuringResharding

In the following snippet:

llvm::SmallVector<GridAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
              std::back_inserter(allReduceGridAxes),
              [&targetShardingPartialAxesSet](Axis a) {
                return targetShardingPartialAxesSet.contains(a);
              });

It should be writing to remainingPartialAxes, not allReduceGridAxes. Corrected version:

llvm::copy_if(sourceShardingPartialAxesSet,
              std::back_inserter(remainingPartialAxes),
              [&targetShardingPartialAxesSet](Axis a) {
                return targetShardingPartialAxesSet.contains(a);
              });
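
To make the effect of the wrong destination concrete, here is a minimal sketch using the standard library in place of the LLVM helpers. The axis values are made up for illustration, and PartialAxes, collectSharedPartialAxes, writeToRemaining and checkPartialAxes are hypothetical names. The sketch models only the copy shown above and starts from an empty allReduceGridAxes, which is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <set>
#include <vector>

using AxisSet = std::set<int>;
using AxisList = std::vector<int>;

struct PartialAxes {
  AxisList allReduceGridAxes;    // assumed empty before the copy
  AxisList remainingPartialAxes;
};

// Copies the axes that are partial in both source and target.
// writeToRemaining == false mirrors the quoted snippet; true mirrors the fix.
PartialAxes collectSharedPartialAxes(const AxisSet &source,
                                     const AxisSet &target,
                                     bool writeToRemaining) {
  PartialAxes result;
  AxisList &destination = writeToRemaining ? result.remainingPartialAxes
                                           : result.allReduceGridAxes;
  std::copy_if(source.begin(), source.end(), std::back_inserter(destination),
               [&target](int axis) { return target.count(axis) != 0; });
  return result;
}

void checkPartialAxes() {
  const AxisSet source = {0, 1, 2}; // hypothetical source partial axes
  const AxisSet target = {1};       // hypothetical target partial axes
  const PartialAxes quoted = collectSharedPartialAxes(source, target, false);
  assert(quoted.remainingPartialAxes.empty()); // never populated
  assert((quoted.allReduceGridAxes == AxisList{1}));
  const PartialAxes fixed = collectSharedPartialAxes(source, target, true);
  assert((fixed.remainingPartialAxes == AxisList{1}));
  assert(fixed.allReduceGridAxes.empty());
}
```

With the quoted destination, the shared axis ends up in the all-reduce list and remainingPartialAxes stays empty.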

Please let me know if a patch is desired — I'm happy to contribute a PR for these changes.

@llvmbot
Member

llvmbot commented Apr 17, 2025

@llvm/issue-subscribers-mlir

Author: None (zhangdianchen)
