I think the issue is in `compute_result_index` and has to do with the results being expanded along different dimensions. I suspect this comes from `dim_scaling` not including the N dimension, which, based on `get_dim_scaling`, appears to be because N isn't in the vector shapes for the reduction. Setting it explicitly in the hardware constraints seems to fix the issue, but it seems like this shouldn't be necessary (and neither should the degenerate distribution).
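The suspected failure mode can be illustrated with a toy model. This is a minimal sketch under stated assumptions, not the actual expansion code: `flatten_expansions` and `naive_result_index` are hypothetical helpers, and the model assumes the indexer locates each result's slot in the flattened list from a per-result copy count derived from `dim_scaling`. If N is missing from that table, the copy count comes out as 1, so the slot computed for result 1 lands on the second expanded copy of result 0. That matches the symptom of the second init arg being replaced by the second part of the first result's expansion.

```python
from itertools import product

# Hypothetical toy model of reduction-result expansion.
# None of these names correspond to the real expansion pass.


def flatten_expansions(result_dims, actual_scaling):
    """Lay out every expanded copy of every result, in result order.

    result_dims: per-result tuple of the dims that result is expanded along.
    actual_scaling: how many copies each dim really produces.
    """
    flat = []
    for result_number, dims in enumerate(result_dims):
        ranges = [range(actual_scaling.get(d, 1)) for d in dims]
        for offsets in product(*ranges):
            flat.append((result_number, dict(zip(dims, offsets))))
    return flat


def naive_result_index(result_number, shared_dims, assumed_scaling):
    """Slot of the first copy of `result_number`, assuming every result
    expands to the same number of copies along `shared_dims`, using the
    scaling table the indexer believes in (hypothetical simplification)."""
    copies_per_result = 1
    for d in shared_dims:
        copies_per_result *= assumed_scaling.get(d, 1)
    return result_number * copies_per_result


if __name__ == "__main__":
    # register_0 is expanded along N into 2 copies; register_1 is not expanded.
    flat = flatten_expansions([("N",), ()], {"N": 2})
    print(flat)  # [(0, {'N': 0}), (0, {'N': 1}), (1, {})]

    # N missing from the scaling table: result 1 resolves to result 0's 2nd copy.
    print(flat[naive_result_index(1, ("N",), {})])  # (0, {'N': 1})

    # N present in the scaling table: result 1 resolves correctly.
    print(flat[naive_result_index(1, ("N",), {"N": 2})])  # (1, {})
```

In this toy, adding N to the scaling table is what makes the lookup correct, which mirrors the observation that setting the vector shape for N explicitly avoids the bug.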
While trying the workaround I found for #381 (using a degenerate distribution along the problematic dimension), I found an issue in reduction expansion. I'm not sure precisely which bounds trigger the bug, but it happens when there are two init args/return values from MMA outputs of a reduction and the first one gets expanded along an unrelated dimension. The second return value/init arg is then dropped and replaced with the second part of the expansion of the first return value. Here's a repro Python test:
Python test repro
The test results in the following IR before `expand_graph`:
IR before expand graph
Right before the `fixup_reduction_nodes` call in expansion, the IR looks like this:

Before `fixup_reduction_nodes`

I think the metadata has already been corrupted at this point, though, because after `fixup_reduction_nodes` we get this:

After `fixup_reduction_nodes`
The init args related to `register_1` have been lost. This then results in an error further along when emitting the MLIR.

FYI @harsh-nod