
[mlir] BlockEquivalenceData is wrong? #123375

Closed
makslevental opened this issue Jan 17, 2025 · 3 comments
Labels: mlir, question (A question, not bug report. Check out https://llvm.org/docs/GettingInvolved.html instead!)


@makslevental (Contributor)

```mlir
tt.func @condBranch(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
  %c1024_i32 = arith.constant 1024 : i32
  %0 = tt.get_program_id x : i32
  %1 = arith.muli %0, %c1024_i32 : i32
  %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
  %3 = tt.splat %1 : i32 -> tensor<1024xi32>
  %4 = arith.addi %3, %2 : tensor<1024xi32>
  %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
  %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
  cf.cond_br %arg1, ^bb1(%5 : tensor<1024x!tt.ptr<f32>>), ^bb2(%6 : tensor<1024x!tt.ptr<f32>>)
^bb1(%7: tensor<1024x!tt.ptr<f32>>):  // pred: ^bb0
  %8 = tt.load %7 : tensor<1024x!tt.ptr<f32>>
  tt.return %8 : tensor<1024xf32>
^bb2(%9: tensor<1024x!tt.ptr<f32>>):  // pred: ^bb0
  %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>>
  tt.return %10 : tensor<1024xf32>
}
```

Running `mlir::simplifyRegions` on it gives:

```mlir
tt.func @condBranch(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
  %c0_i64 = arith.constant 0 : i64
  %0 = builtin.unrealized_conversion_cast %arg0, %c0_i64 : !tt.ptr<f32>, i64 to !tt.ptr<f32>
  %c1024_i32 = arith.constant 1024 : i32
  %1 = tt.get_program_id x : i32
  %2 = arith.muli %1, %c1024_i32 : i32
  %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
  %4 = tt.splat %2 : i32 -> tensor<1024xi32>
  %5 = arith.addi %4, %3 : tensor<1024xi32>
  %6 = tt.splat %0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
  %7 = tt.addptr %6, %5 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
  cf.cond_br %arg1, ^bb1(%6 : tensor<1024x!tt.ptr<f32>>), ^bb1(%7 : tensor<1024x!tt.ptr<f32>>)
^bb1(%8: tensor<1024x!tt.ptr<f32>>):  // 2 preds: ^bb0, ^bb0
  %9 = tt.load %8 : tensor<1024x!tt.ptr<f32>>
  tt.return %9 : tensor<1024xf32>
}
```

This happens because `BlockEquivalenceData` treats `^bb1` and `^bb2` as equivalent; its doc comment says:

> /// This class contains the information for comparing the equivalencies of two
> /// blocks. Blocks are considered equivalent if they contain the same operations
> /// in the same order. The only allowed divergence is for operands that come
> /// from sources outside of the parent block, i.e. the uses of values produced
> /// within the block must be equivalent.

I don't understand how that's a legal merge/rewrite/change.

@llvmbot (Member) commented Jan 17, 2025

@llvm/issue-subscribers-mlir

Author: Maksim Levental (makslevental)


@jpienaar (Member)

Could you expand? In the original, `^bb1` and `^bb2` both 1) load from their block argument 0 and 2) return the loaded value. Isn't that still true after simplification? (I'm ignoring the `unrealized_conversion_cast` there, as I think you are referring to going from 3 blocks to 2.) Is there something observable that you think would make that check invalid?
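To make the check concrete, here is a minimal Python sketch of the rule quoted from the `BlockEquivalenceData` doc comment, run on a hypothetical toy IR. The `Op`, `Block` and `compare_blocks` names are invented for illustration and are not MLIR's API, and operand types and attributes are ignored for brevity. Under this model, `^bb1` and `^bb2` from the issue compare as equivalent, because each one only uses its own block argument 0 and the result of its own load.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Hypothetical toy IR (not MLIR's API). An operand is a (kind, key) pair:
#   ("arg", i)     -> the i-th argument of the enclosing block
#   ("local", j)   -> the result of the j-th op in the same block
#   ("outside", v) -> a value defined outside the block, identified by name v
Operand = Tuple[str, object]


@dataclass
class Op:
    name: str
    operands: List[Operand]


@dataclass
class Block:
    arg_types: List[str]
    ops: List[Op]


def compare_blocks(a: Block, b: Block) -> Optional[List[Tuple[int, int]]]:
    """Return None if a and b are not equivalent under the quoted rule;
    otherwise return the (op index, operand index) positions where the
    blocks use different values defined outside the block."""
    if a.arg_types != b.arg_types or len(a.ops) != len(b.ops):
        return None
    mismatches = []
    for i, (op_a, op_b) in enumerate(zip(a.ops, b.ops)):
        # Same operations in the same order.
        if op_a.name != op_b.name or len(op_a.operands) != len(op_b.operands):
            return None
        for j, (lhs, rhs) in enumerate(zip(op_a.operands, op_b.operands)):
            if lhs == rhs:
                continue
            lhs_inside = lhs[0] != "outside"
            rhs_inside = rhs[0] != "outside"
            if lhs_inside != rhs_inside:
                return None  # one use is internal, the other external
            if lhs_inside:
                return None  # internal uses must refer to the same position
            mismatches.append((i, j))  # external values may diverge
    return mismatches


PTRS = "tensor<1024x!tt.ptr<f32>>"

# ^bb1 and ^bb2 from the issue: load from block argument 0, return the load.
bb1 = Block([PTRS], [Op("tt.load", [("arg", 0)]), Op("tt.return", [("local", 0)])])
bb2 = Block([PTRS], [Op("tt.load", [("arg", 0)]), Op("tt.return", [("local", 0)])])
print(compare_blocks(bb1, bb2))  # []

# Hypothetical variant: each block loads a different outside value directly.
bb3 = Block([], [Op("tt.load", [("outside", "%5")]), Op("tt.return", [("local", 0)])])
bb4 = Block([], [Op("tt.load", [("outside", "%6")]), Op("tt.return", [("local", 0)])])
print(compare_blocks(bb3, bb4))  # [(0, 0)]
```

Values from outside the block are the only allowed divergence; as far as I understand MLIR's region simplification, such positions become extra block arguments when blocks are merged. The issue's blocks need none, because the pointers that differ are already passed in as successor operands of `cf.cond_br`, so the merged block receives the right value on each edge.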

@makslevental (Contributor, Author)

Ya, I'm blind; I don't know how I missed that the new `cond_br` still passes both of its operands correctly:

```mlir
cf.cond_br %arg1, ^bb1(%6 : tensor<1024x!tt.ptr<f32>>), ^bb1(%7 : tensor<1024x!tt.ptr<f32>>)
```

EugeneZelenko added the question label ("A question, not bug report. Check out https://llvm.org/docs/GettingInvolved.html instead!") on Jan 18, 2025.