[SCFToCalyx] Issue with lowering scf::parallel to Calyx when there is scf::if in it #8086

Closed
jiahanxie353 opened this issue Jan 15, 2025 · 9 comments
Labels
Calyx The Calyx dialect

Comments

@jiahanxie353
Contributor

We lower scf::parallel in SCFToCalyx by assuming constant loop bounds and steps, which lets us unroll the parallel loop manually: #7830
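The unrolling that #7830 relies on only needs the iteration space of the loop. A minimal Python sketch of that idea, not the SCFToCalyx code: unroll_parallel is a hypothetical helper that assumes every lower bound, upper bound, and step is a known integer constant (steps positive, upper bounds exclusive, as in scf.parallel) and returns the induction-variable values that each unrolled copy of the body would receive.

```python
from itertools import product


def unroll_parallel(lower, upper, step):
    """Enumerate induction-variable tuples of a constant-bound scf.parallel.

    Hypothetical helper: lower, upper and step are equal-length tuples of
    ints, one entry per loop dimension; upper bounds are exclusive.
    """
    if not (len(lower) == len(upper) == len(step)):
        raise ValueError("bounds and steps must have the same rank")
    if any(s <= 0 for s in step):
        raise ValueError("steps must be positive")
    ranges = [range(lb, ub, s) for lb, ub, s in zip(lower, upper, step)]
    # One tuple per unrolled copy of the loop body, in row-major order.
    return list(product(*ranges))


# The loop from this issue: (%arg1, %arg2) = (0, 0) to (2, 2) step (1, 1)
iterations = unroll_parallel((0, 0), (2, 2), (1, 1))
print(iterations)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

Each tuple corresponds to one copy of the body in which %arg1 and %arg2 are replaced by constants, which is exactly what makes the later folding possible.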

The issue is that if scf::parallel contains an scf::if whose condition depends only on constants and scf::parallel's loop induction variables, then after unrolling, each copy of that condition is eventually folded to a constant true/false automatically, even though we never implement this sort of "canonicalization" manually in SCFToCalyx.cpp (I guess it's done by some underlying MLIR mechanism?).

For example, if we begin with:

func.func() {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %1 = arith.remsi %arg2, %c2 : index
      %2 = arith.cmpi slt, %1, %c0 : index
      %3 = scf.if %2 -> (f32) {
        scf.yield ...
      }
    }
}

will turn into:

func.func() {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    ^bb0:
      %1 = arith.remsi "value 0", %c2 : index
      %2 = arith.cmpi slt, %1, %c0 : index
      %3 = scf.if %2 -> (f32) {
        scf.yield ...
      }
    ^bb1:
      %1 = arith.remsi "value 1", %c2 : index
      %2 = arith.cmpi slt, %1, %c0 : index
      %3 = scf.if %2 -> (f32) {
        scf.yield ...
      }
    ^bb2:
    ...
    }
}

And in each basic block:

%1 = arith.remsi "value 0", %c2 : index
%2 = arith.cmpi slt, %1, %c0 : index

will be folded to constants automatically, turning the scf::if op into:

%3 = scf.if true/false {
 ...
}
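Concretely, once unrolling substitutes a constant for %arg2, both ops have only constant operands and can be evaluated at compile time. A minimal Python sketch of that evaluation; remsi and fold_condition are hypothetical helpers that mirror the documented semantics of arith.remsi (the remainder takes the sign of the dividend) and arith.cmpi slt, not MLIR's actual folders, and they ignore fixed-width overflow, which cannot occur for these small values. For this particular loop every copy folds to false, since %arg2 is never negative.

```python
def remsi(lhs: int, rhs: int) -> int:
    """Signed remainder like arith.remsi: the result takes the sign of lhs."""
    if rhs == 0:
        raise ZeroDivisionError("arith.remsi by zero is undefined behavior")
    r = abs(lhs) % abs(rhs)
    return -r if lhs < 0 else r


def fold_condition(arg2: int, c2: int = 2, c0: int = 0) -> bool:
    """Fold `arith.cmpi slt, (arith.remsi %arg2, %c2), %c0` for a known %arg2."""
    return remsi(arg2, c2) < c0


# After unrolling, each copy of the body sees a constant %arg2 (0 or 1 here).
for arg2 in (0, 1):
    print(arg2, fold_condition(arg2))  # both copies fold to False
```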

If such an scf::if were lowered to Calyx, it would become:

calyx.if 1'd1/1'd0 {
...
}

but this is not allowed by Calyx's grammar: https://github.com/calyxir/calyx/blob/2e18d555ff00339f46c99cc9d8b698ee27ce25ad/calyx-frontend/src/syntax.pest#L348

Any ideas on how to tackle this? @rachitnigam @cgyurgyik @andrewb1999 @mikeurbach Thanks!

@jiahanxie353 self-assigned this Jan 15, 2025
@jiahanxie353 added the Calyx (The Calyx dialect) label Jan 15, 2025
@cgyurgyik
Member

Can't we just simplify scf.if true { body } => body?
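A minimal Python sketch of that rewrite on a hypothetical toy IR (Op, If, and remove_static_conditions are illustrative names, not MLIR APIs): an if whose condition has already been folded to a constant is replaced by the ops of the branch that would be taken. It deliberately ignores the values an scf.if yields; a real rewrite must also replace the if's results with the operands of the taken branch's scf.yield.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Op:
    """A toy operation: just a name, standing in for any non-control-flow op."""
    name: str


@dataclass
class If:
    """A toy scf.if whose condition is a folded constant or an SSA name."""
    cond: Union[bool, str]
    then_ops: List["Node"] = field(default_factory=list)
    else_ops: List["Node"] = field(default_factory=list)


Node = Union[Op, If]


def remove_static_conditions(ops: List[Node]) -> List[Node]:
    """Replace `if true {A} else {B}` with A and `if false {A} else {B}` with B."""
    result: List[Node] = []
    for op in ops:
        if isinstance(op, If):
            # Simplify nested ifs first, then decide on this one.
            then_ops = remove_static_conditions(op.then_ops)
            else_ops = remove_static_conditions(op.else_ops)
            if op.cond is True:
                result.extend(then_ops)  # inline the taken branch
            elif op.cond is False:
                result.extend(else_ops)
            else:
                result.append(If(op.cond, then_ops, else_ops))
        else:
            result.append(op)
    return result


body = [Op("a"), If(True, [Op("b")], [Op("c")])]
print(remove_static_conditions(body))  # [Op(name='a'), Op(name='b')]
```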

@jiahanxie353
Contributor Author

> Can't we just simplify scf.if true { body } => body?

We could; I just thought it should already exist in a canonicalization pass for the SCF dialect. I didn't want to introduce duplicate code into the codebase, but on the other hand I felt we might have to, so I wanted to bring this up.

@cgyurgyik
Member

This looks like it is simplified after canonicalization: https://godbolt.org/z/j833qxMEh

@jiahanxie353
Contributor Author

> This looks like it is simplified after canonicalization: https://godbolt.org/z/j833qxMEh

Indeed, the issue actually has to do with the way we lower scf::parallel. When lowering to Calyx, we create multiple blocks in its body (#7830). However, this violates scf::parallel's invariants; its documentation states:

> The body region must contain exactly one block that terminates with a scf.reduce operation.

(source: https://mlir.llvm.org/docs/Dialects/SCFDialect/#scfparallel-scfparallelop)

As a result, running the canonicalization pass erases every block in the scf::parallel region except the first one.
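One way to see why only the first block survives: execution of a region starts at its entry block, so a block that no terminator branches to is unreachable. MLIR's greedy rewrite driver, which the canonicalizer uses, simplifies regions by default, and that simplification erases unreachable blocks. A minimal Python sketch of the reachability check, assuming (as in the unrolled IR earlier in this issue) that the per-iteration blocks never branch to one another; reachable_blocks and erase_unreachable are hypothetical names.

```python
from collections import deque
from typing import Dict, List, Set


def reachable_blocks(successors: Dict[str, List[str]], entry: str) -> Set[str]:
    """Return the blocks reachable from `entry` by following branch successors."""
    seen = {entry}
    worklist = deque([entry])
    while worklist:
        block = worklist.popleft()
        for succ in successors.get(block, []):
            if succ not in seen:
                seen.add(succ)
                worklist.append(succ)
    return seen


def erase_unreachable(successors: Dict[str, List[str]], entry: str) -> List[str]:
    """Keep only reachable blocks, preserving their original order."""
    live = reachable_blocks(successors, entry)
    return [b for b in successors if b in live]


# Unrolled scf.parallel body: one block per iteration, none branching to another.
unrolled = {"bb0": [], "bb1": [], "bb2": [], "bb3": []}
print(erase_unreachable(unrolled, "bb0"))  # ['bb0']
```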

@jiahanxie353
Contributor Author

I found a potential workaround that keeps the pass from erasing the blocks within the newly created scf::parallel: wrap each block in an scf::execute_region (https://mlir.llvm.org/docs/Dialects/SCFDialect/#scfexecute_region-scfexecuteregionop). Trying it out.

@jiahanxie353
Contributor Author

> I found a potential workaround that keeps the pass from erasing the blocks within the newly created scf::parallel: wrap each block in an scf::execute_region (https://mlir.llvm.org/docs/Dialects/SCFDialect/#scfexecute_region-scfexecuteregionop). Trying it out.

No, it doesn't work very well: the execute_region ops get inlined and then eliminated by the canonicalization pass as well...

@cgyurgyik
Member

This seems to be a different issue from the one you described in the first comment. To check my understanding: at some point in SCFToCalyx, an scf::ParOp's body is turned from a legal representation (a single block) into an illegal one (more than one block). Then, when you canonicalize to get rid of the static condition in scf::IfOp, these additional blocks are incorrectly removed. Is this correct?

If so, I'm surprised this doesn't throw an error instead of just blindly deleting blocks. Is it possible to keep a single block when lowering scf::ParOp, or to avoid canonicalizing while it is in an illegal state?

@jiahanxie353
Contributor Author

> This seems to be a different issue than you originally stated in the first comment. Checking my understanding, at some point in SCFToCalyx, an scf::ParOp's body is turned from a legal representation (single block) to an illegal representation (> 1 blocks). Then, when you canonicalize to try and get rid of the static condition in scf::IfOp, these additional blocks are incorrectly removed. Is this correct?
>
> If so, I'm surprised this doesn't throw an error instead of just blindly deleting blocks. Is it possible to keep a single block when lowering scf::ParOp, or avoid canonicalizing when it is in an illegal state?

Your understanding is totally correct. I'm also surprised by that.

> Is it possible to keep a single block when lowering scf::ParOp

Yes, I already have it locally, and I plan to open a PR that wraps the blocks in execute_region operations, one execute_region per block.
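A minimal Python sketch of the intended layout, on a hypothetical toy model (ExecuteRegion, ParallelBody, and wrap_iterations are illustrative names, and this is not necessarily how #8098 or #8103 implement it): each unrolled iteration becomes one execute_region, and all of them live in the single body block that scf.parallel requires, followed by the scf.reduce terminator.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ExecuteRegion:
    """Toy scf.execute_region holding the op names of one unrolled iteration."""
    ops: List[str]


@dataclass
class ParallelBody:
    """Toy scf.parallel body region: a list of blocks, each a list of ops."""
    blocks: List[List[object]] = field(default_factory=list)

    def verify(self) -> None:
        # scf.parallel requires exactly one block in its body region.
        if len(self.blocks) != 1:
            raise ValueError(f"expected 1 block, found {len(self.blocks)}")


def wrap_iterations(per_iteration_ops: List[List[str]]) -> ParallelBody:
    """Put each iteration's ops in its own ExecuteRegion inside a single block."""
    block: List[object] = [ExecuteRegion(list(ops)) for ops in per_iteration_ops]
    block.append("scf.reduce")  # the single body block ends with scf.reduce
    return ParallelBody(blocks=[block])


body = wrap_iterations([["arith.remsi", "arith.cmpi", "scf.if"]] * 4)
body.verify()  # passes: one block with four execute_region ops plus scf.reduce
```

Per the earlier comment, a full --canonicalize run would still inline these execute_region ops, so this layout only holds up if the canonicalization that runs afterwards leaves them alone.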

Apart from that, I just found out that I can apply op-specific canonicalization patterns. The --canonicalize pass in MLIR canonicalizes everything, but since my goal is only to canonicalize scf::if, I'll just apply the scf::if-specific canonicalization patterns, which I found here: https://github.com/llvm/llvm-project/blob/1434313bd8c425b2aadc301ddaf42a91552e609e/mlir/lib/Dialect/SCF/IR/SCF.cpp#L2811
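The difference between running the whole canonicalizer and running only the scf::if patterns can be illustrated with a toy rewrite driver. A minimal Python sketch on a hypothetical tuple-based IR; fold_static_if, inline_execute_region, and rewrite are illustrative stand-ins, not MLIR's patterns or its greedy driver. In MLIR itself, the restricted set would come from calling scf::IfOp::getCanonicalizationPatterns on a RewritePatternSet and applying only that set.

```python
from typing import Callable, List, Optional, Tuple

Op = Tuple  # ("op", name) | ("if", cond, then_ops, else_ops) | ("execute_region", ops)
Pattern = Callable[[Op], Optional[List[Op]]]


def fold_static_if(op: Op) -> Optional[List[Op]]:
    """An if with a constant (bool) condition becomes its taken branch."""
    if op[0] == "if" and isinstance(op[1], bool):
        return list(op[2] if op[1] else op[3])
    return None


def inline_execute_region(op: Op) -> Optional[List[Op]]:
    """An execute_region is replaced by the ops of its body."""
    if op[0] == "execute_region":
        return list(op[1])
    return None


def rewrite(ops: List[Op], patterns: List[Pattern]) -> List[Op]:
    """Apply patterns to a fixpoint, rewriting nested regions first."""
    out: List[Op] = []
    for op in ops:
        if op[0] == "execute_region":
            op = ("execute_region", rewrite(op[1], patterns))
        elif op[0] == "if":
            op = ("if", op[1], rewrite(op[2], patterns), rewrite(op[3], patterns))
        for pattern in patterns:
            replacement = pattern(op)
            if replacement is not None:
                # Replacement ops may match again, so rewrite them too.
                out.extend(rewrite(replacement, patterns))
                break
        else:
            out.append(op)
    return out


body = [
    ("execute_region", [("if", False, [("op", "a")], [("op", "b")])]),
    ("execute_region", [("if", True, [("op", "c")], [("op", "d")])]),
    ("op", "scf.reduce"),
]
only_if = rewrite(body, [fold_static_if])
full = rewrite(body, [fold_static_if, inline_execute_region])
```

With only the scf::if pattern, each execute_region (and so each unrolled iteration) survives while the constant conditions disappear; adding the inlining pattern, as a full --canonicalize run effectively does per the earlier comment, flattens them away.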

@jiahanxie353
Contributor Author

Closed by #8098 and #8103.

2 participants