Issues enabling ND mesh resharding in Spmdization pass: incorrect axis comparison and resharding assertion failures #136117
@llvm/issue-subscribers-mlir Author: None (zhangdianchen)
When trying to enable ND mesh resharding in the Spmdization pass of MLIR, I encountered several issues that cause incorrect behavior or assertion failures. Below is a detailed breakdown:
### 1. Incorrect detection logic in `detectMoveLastSplitAxisInResharding`
**Problem Reproduction:**
When executing the following resharding sequence:
```mlir
// Move mesh axis 1 from tensor dimension 0 to tensor dimension 1.
// Each sharding gets its own SSA name; redefining %sharding is invalid IR.
%sharding1 = mesh.sharding @mesh_3d split_axes = [[0, 1], [2]] : !mesh.sharding
%in1_sharded1 = mesh.shard %in1 to %sharding1 : tensor<8x16xi8>
%sharding2 = mesh.sharding @mesh_3d split_axes = [[0], [1, 2]] : !mesh.sharding
%in1_sharded2 = mesh.shard %in1_sharded1 to %sharding2 annotate_for_users : tensor<8x16xi8>
```
The pass is expected to detect a valid last-axis move and insert a `mesh.all_to_all` operation. Instead, it crashes with the following assertion:
``mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `targetShard && "Did not find any pattern to apply."' failed.``
**Root Cause:**
In `detectMoveLastSplitAxisInResharding`, the logic checks:
In the example `[[0, 1], [2]] -> [[0], [1, 2]]`, this compares:
It should instead compare `source.back()` with `target.front()`:
`sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() != targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().front()`
A second check is also incorrect:
It compares the wrong slices: in the example, it ends up comparing `[0]` with `[1]`. Instead, it should skip the first axis of the target and compare:
This now compares `[0]` with `[2]`, which is correct.
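To make the proposed detection concrete, here is a minimal, self-contained C++ model. It assumes split axes can be represented as `std::vector<std::vector<int>>` and, as this issue proposes, that a valid move places the moved mesh axis at the front of the target dimension. `SplitAxes` and `detectMoveLastSplitAxis` are hypothetical names, not MLIR APIs. Rather than mirroring the second slice comparison, the sketch applies the candidate move to a copy of the source and requires an exact match with the target, which avoids the slice bookkeeping altogether.

```cpp
#include <cstddef>
#include <optional>
#include <tuple>
#include <vector>

// Hypothetical stand-in for a sharding's split axes: one list of mesh axes
// per tensor dimension, e.g. {{0, 1}, {2}}.
using SplitAxes = std::vector<std::vector<int>>;

// Returns (sourceTensorAxis, targetTensorAxis, movedMeshAxis) if `target`
// is obtained from `source` by removing the last mesh axis of one tensor
// dimension and inserting it at the front of another dimension.
// Simplified model only; this is not the MLIR implementation.
std::optional<std::tuple<std::size_t, std::size_t, int>>
detectMoveLastSplitAxis(const SplitAxes &source, const SplitAxes &target) {
  for (std::size_t src = 0; src < source.size(); ++src) {
    for (std::size_t dst = 0; dst < target.size(); ++dst) {
      if (src == dst || source[src].empty() || target[dst].empty())
        continue;
      // Proposed first check: source.back() against target.front().
      int moved = source[src].back();
      if (moved != target[dst].front())
        continue;
      // Apply the candidate move to a copy and require an exact match.
      SplitAxes candidate = source;
      if (candidate.size() < target.size())
        candidate.resize(target.size());  // pad with unsplit dimensions
      candidate[src].pop_back();
      candidate[dst].insert(candidate[dst].begin(), moved);
      if (candidate == target)
        return std::make_tuple(src, dst, moved);
    }
  }
  return std::nullopt;
}
```

With the reproduction's shardings, `detectMoveLastSplitAxis({{0, 1}, {2}}, {{0}, {1, 2}})` yields source dimension 0, target dimension 1 and mesh axis 1, while `{{0}, {2, 1}}` as the target is rejected.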
### 2. Incorrect ShardingTarget construction in `targetShardingInMoveLastAxis`
Bypassing the assertion above leads to another failure:
``mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `actualTargetSharding == targetSharding' failed.``
**Root Cause:** In `targetShardingInMoveLastAxis`, the `targetShardingSplitAxes` are ordered incorrectly.
Current result:
**Fix:** Instead of:
`targetSplitAxes.push_back(gridAxis);`
Use:
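The ordering problem can be reproduced in a small, self-contained C++ model. It assumes split axes are `std::vector<std::vector<int>>` and takes the expected result of the move from the reproduction (`[[0, 1], [2]]` becoming `[[0], [1, 2]]`). `SplitAxes`, `moveLastAxis` and the `prepend` flag are hypothetical; `gridAxis` only mirrors the variable name in the `push_back` line above. In this model, appending yields `[[0], [2, 1]]`, whereas inserting at the front yields the expected `[[0], [1, 2]]`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for split axes: one list of mesh axes per tensor dim.
using SplitAxes = std::vector<std::vector<int>>;

// Simplified model of building the target split axes after moving the last
// mesh axis of `sourceTensorAxis` to `targetTensorAxis`.
// prepend == false: the moved axis is appended (push_back).
// prepend == true:  the moved axis is inserted at the front of the dimension.
// Precondition: axes[sourceTensorAxis] is not empty.
SplitAxes moveLastAxis(SplitAxes axes, std::size_t sourceTensorAxis,
                       std::size_t targetTensorAxis, bool prepend) {
  // Pad with unsplit dimensions before taking any references into `axes`.
  if (axes.size() <= targetTensorAxis)
    axes.resize(targetTensorAxis + 1);
  std::vector<int> &src = axes[sourceTensorAxis];
  int gridAxis = src.back();
  src.pop_back();
  std::vector<int> &dst = axes[targetTensorAxis];
  if (prepend)
    dst.insert(dst.begin(), gridAxis);
  else
    dst.push_back(gridAxis);
  return axes;
}
```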
### 3. Bug in `handlePartialAxesDuringResharding`
In the following snippet:
It should write to `remainingPartialAxes`, not `allReduceGridAxes`. Corrected version:
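The intent of this fix can be illustrated with a minimal C++ model that uses `std::set<int>` in place of MLIR's axis sets. Based on the variable names in this issue, it assumes that axes partial in the source but not in the target are all-reduced, while axes partial in both remain partial. `splitPartialAxes` is a hypothetical helper, not the MLIR function; the two output names mirror the ones above.

```cpp
#include <algorithm>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

// Simplified model: returns {allReduceGridAxes, remainingPartialAxes}.
std::pair<std::vector<int>, std::vector<int>>
splitPartialAxes(const std::set<int> &sourcePartial,
                 const std::set<int> &targetPartial) {
  // Partial in the source but not in the target: these must be all-reduced.
  std::vector<int> allReduceGridAxes;
  std::set_difference(sourcePartial.begin(), sourcePartial.end(),
                      targetPartial.begin(), targetPartial.end(),
                      std::back_inserter(allReduceGridAxes));
  // Partial in both: these stay partial. The output iterator must point at
  // remainingPartialAxes; writing into allReduceGridAxes instead would mix
  // these axes into the all-reduce list and leave remainingPartialAxes empty.
  std::vector<int> remainingPartialAxes;
  std::copy_if(sourcePartial.begin(), sourcePartial.end(),
               std::back_inserter(remainingPartialAxes),
               [&targetPartial](int axis) {
                 return targetPartial.count(axis) != 0;
               });
  return {allReduceGridAxes, remainingPartialAxes};
}
```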
Please let me know if a patch is desired; I'm happy to contribute a PR for these changes.