[TKW] Bug: reduction expansion loses return values #384

Open
GMNGeoffrey opened this issue Jan 13, 2025 · 2 comments

GMNGeoffrey commented Jan 13, 2025

While trying the workaround I found for #381 (doing a degenerate distribution along the problematic dimension), I hit an issue in reduction expansion. I'm not sure precisely which bounds trigger the bug, but it happens when there are two init args/return values coming from mma outputs of a reduction and the first one gets expanded along a dimension the second doesn't use. The second return value/init arg is then dropped and replaced with the second copy from the expansion of the first return value. Schematically (my own sketch, not compiler output):
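# Illustration only, using the node names from the IR dumps below. The
# reduction has two results and only the first (the dv accumulator) uses N;
# with N expanded by a factor of 2, the expanded return values should be
expected = ["mma_K2:0_N:0_M:0", "mma_K2:0_N:1_M:0", "mma_K2:0_M:0"]
# but expansion produces only the two copies of the first result, dropping
# the dk result entirely:
actual = ["mma_K2:0_N:0_M:0", "mma_K2:0_N:1_M:0"]

Here's a repro Python test: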

Python test repro
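# Assumed imports/setup to make the repro below self-contained. Module paths
# follow the iree-turbine layout at the time of filing and may have moved:
import os
import torch
from torch.testing import assert_close
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.utils import (
    device_randn,
    device_zeros,
    get_default_run_config,
    get_default_scheduling_params,
    get_mfma_load_elems_per_thread,
    get_mfma_store_elems_per_thread,
)

# Assumed to mirror the wave tests, which read this from an env var.
enable_scheduling_barriers = bool(int(os.environ.get("WAVE_USE_SCHED_BARRIERS", "0")))
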
def testReproExpansionOrthogonalToReduction(shape):
    # shape = (1, 16, 32, 16, 16)
    _, q_seq_len, v_head_dim, qk_head_dim, kv_seq_len = shape
    mfma_variant = MMAType.F32_16x16x16_F16
    # Input sizes
    M = tkl.sym.M  # query sequence length
    N = tkl.sym.N  # value head dimension
    K1 = tkl.sym.K1  # query/key head dimension
    K2 = tkl.sym.K2  # key/value sequence length

    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [
        tkw.WorkgroupConstraint(K2, K2, 0),
        tkw.WorkgroupConstraint(N, N, 1),
        tkw.TilingConstraint(M, M),
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
            mma_type=mfma_variant,
        )
    ]

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)

    flip_m_k1_read_mapping = tkw.IndexMapping(
        num_iterators=2, inputs={M: j, K1: i}, outputs={K1: i, M: j},
    )

    @tkw.wave(constraints)
    def attention_bwd(
        q: tkl.Memory[M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
        k: tkl.Memory[K2, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
        do: tkl.Memory[N, M, GLOBAL_ADDRESS_SPACE, tkl.f16],
        dk: tkl.Memory[K2, K1, GLOBAL_ADDRESS_SPACE, tkl.f32],
        dv: tkl.Memory[K2, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
        ds: tkl.Memory[K2, M, GLOBAL_ADDRESS_SPACE, tkl.f16],
    ):

        dv_init = tkl.Register[K2, N, tkl.f32](0.0)
        dk_init = tkl.Register[K2, K1, tkl.f32](0.0)

        @tkw.reduction(M, init_args=[dv_init, dk_init])
        def loop_q_seq_len(
            dv_prev: tkl.Register[K2, N, tkl.f32],
            dk_prev: tkl.Register[K2, K1, tkl.f32],
        ):
            k_j = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            q_i = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)

            s_acc = tkl.Register[M, K2, tkl.f32](0.0)
            s_ij = tkw.mma(q_i, k_j, s_acc)
            s_ij = tkw.permute(s_ij, [K2, M])

            do_i = tkw.read(do, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            dv_j = tkw.mma(tkw.cast(s_ij, tkl.f16), do_i, dv_prev)
            
            ds_ij = tkw.read(ds, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            q_i_for_dk = tkw.read(q, mapping=flip_m_k1_read_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            dk_j = tkw.mma(ds_ij, q_i_for_dk, dk_prev)
            return (dv_j, dk_j)

        (dv_j, dk_j) = loop_q_seq_len
        tkw.write(dv_j, dv, elements_per_thread=STORE_ELEMS_PER_THREAD)
        tkw.write(dk_j, dk, elements_per_thread=STORE_ELEMS_PER_THREAD)

    hyperparams = {
        LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
        M: q_seq_len,
        N: v_head_dim,
        K1: qk_head_dim,
        K2: kv_seq_len,
    }


    hyperparams.update(get_default_scheduling_params())
    config = get_default_run_config()
    # config["print_ir_after_all"] = True
    compile_config = {
        "waves_per_eu": 2,
        "denorm_fp_math_f32": "preserve-sign",
        "print_ir_after": ["expand_graph"],
        "print_ir_before": ["expand_graph"],
        # "print_signature": True,
        # "print_pretty_mlir": True,
        # "print_indices": True,
    }

    with tk.gen.TestLaunchContext(
        hyperparams,
        canonicalize=True,
        run=True,
        run_bench=False,
        run_config=config,
        compile_config=compile_config,
        schedule=False,
        use_scheduling_barriers=enable_scheduling_barriers,
    ):
        
        torch.manual_seed(0)
        q = device_randn(q_seq_len, qk_head_dim, dtype=torch.float16)
        k = device_randn(kv_seq_len, qk_head_dim, dtype=torch.float16)
        do = device_randn(q_seq_len, v_head_dim, dtype=torch.float16)
        ds = device_randn(kv_seq_len, q_seq_len, dtype=torch.float16)

        s_ref = torch.matmul(q, k.transpose(-1, -2))
        dv_ref = torch.matmul(s_ref.transpose(-1, -2), do)

        dk = device_zeros(kv_seq_len, qk_head_dim, dtype=torch.float32)
        dv = device_zeros(kv_seq_len, v_head_dim, dtype=torch.float32)

        mb_bwd = attention_bwd(
            q,
            k,
            do.transpose(-1, -2),
            dk,
            dv,
            ds,
        )

        assert_close(dv, dv_ref.to(torch.float32), atol=1e-3, rtol=1e-3)

which results in the following IR before `expand_graph`:

IR before expand graph
Before expand_graph:
region_0:
graph():
    %dv_prev :  [num_users=1] = placeholder[target=dv_prev]
    %dk_prev :  [num_users=1] = placeholder[target=dk_prev]
    %k :  [num_users=1] = placeholder[target=k]
    %read :  [num_users=1] = [read](args = (%k, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %q :  [num_users=2] = placeholder[target=q]
    %read_1 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %register :  [num_users=1] = [register](args = ((M, K2), f32, 0.0), kwargs = {})
    %mma :  [num_users=1] = [mma](args = (%read_1, %read, %register, None), kwargs = {})
    %permute :  [num_users=1] = [permute](args = (%mma, [K2, M]), kwargs = {})
    %do :  [num_users=1] = placeholder[target=do]
    %read_2 :  [num_users=1] = [read](args = (%do, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %cast :  [num_users=1] = [cast](args = (%permute, f16), kwargs = {})
    %mma_1 :  [num_users=1] = [mma](args = (%cast, %read_2, %dv_prev, None), kwargs = {})
    %ds :  [num_users=1] = placeholder[target=ds]
    %read_3 :  [num_users=1] = [read](args = (%ds, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_4 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, IndexMapping(iters={$index0: 0, $index1: 1}, input_mapping={M: $index1, K1: $index0}), output_mapping={K1: $index0, M: $index1}, dynamic_val_mappings=(), (), None), kwargs = {})
    %mma_2 :  [num_users=1] = [mma](args = (%read_3, %read_4, %dk_prev, None), kwargs = {})
    return (mma_1, mma_2)
region_1 [root]:
graph():
    %q :  [num_users=1] = placeholder[target=q]
    %k :  [num_users=1] = placeholder[target=k]
    %do :  [num_users=1] = placeholder[target=do]
    %dk :  [num_users=1] = placeholder[target=dk]
    %dv :  [num_users=1] = placeholder[target=dv]
    %ds :  [num_users=1] = placeholder[target=ds]
    %register :  [num_users=1] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %register_1 :  [num_users=1] = [register](args = ((K2, K1), f32, 0.0), kwargs = {})
    %reduction : [Register[K2, N].of(f32), Register[K2, K1].of(f32)] [num_users=2] = [reduction](args = (M, [%register, %register_1], region_0, [%k, %q, %do, %ds]), kwargs = {})
    %getitem :  [num_users=1] = [getitem](args = (%reduction, 0), kwargs = {})
    %getitem_1 :  [num_users=1] = [getitem](args = (%reduction, 1), kwargs = {})
    %write :  [num_users=0] = [write](args = (%getitem, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write_1 :  [num_users=0] = [write](args = (%getitem_1, %dk, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    return None

Right before the `fixup_reduction_nodes` call in expansion, the IR looks like this. I think the metadata has already been corrupted at this point, though:

Before `fixup_reduction_nodes`
region_0:
graph():
    %dv_K2:0_N:0_M:0 :  [num_users=1] = placeholder[target=dv_prev]
    %dv_K2:0_N:1_M:0 :  [num_users=1] = placeholder[target=dv_prev]
    %dv_prev :  [num_users=1] = placeholder[target=dv_prev]
    %dk_K2:0_M:0 :  [num_users=1] = placeholder[target=dk_prev]
    %dk_prev :  [num_users=1] = placeholder[target=dk_prev]
    %k :  [num_users=2] = placeholder[target=k]
    %read_K2:0_N:0_M:0_K1:0 :  [num_users=1] = [read](args = (%k, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read :  [num_users=1] = [read](args = (%k, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %q :  [num_users=4] = placeholder[target=q]
    %read_K2:0_N:0_M:0_K1:0 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_1 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %register_K2:0_N:0_M:0_K1:0 :  [num_users=1] = [register](args = ((M, K2), f32, 0.0), kwargs = {})
    %register :  [num_users=1] = [register](args = ((M, K2), f32, 0.0), kwargs = {})
    %mma_K2:0_N:0_M:0_K1:0 :  [num_users=1] = [mma](args = (%read_K2:0_N:0_M:0_K1:0, %read_K2:0_N:0_M:0_K1:0, %register_K2:0_N:0_M:0_K1:0, None), kwargs = {})
    %mma :  [num_users=1] = [mma](args = (%read_1, %read, %register, None), kwargs = {})
    %permute_K2:0_N:0_M:0 :  [num_users=1] = [permute](args = (%mma_K2:0_N:0_M:0_K1:0, [K2, M]), kwargs = {})
    %permute :  [num_users=1] = [permute](args = (%mma, [K2, M]), kwargs = {})
    %do :  [num_users=3] = placeholder[target=do]
    %read_K2:0_N:0_M:0 :  [num_users=1] = [read](args = (%do, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_K2:0_N:1_M:0 :  [num_users=1] = [read](args = (%do, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_2 :  [num_users=1] = [read](args = (%do, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %cast_K2:0_N:0_M:0 :  [num_users=2] = [cast](args = (%permute_K2:0_N:0_M:0, f16), kwargs = {})
    %cast :  [num_users=1] = [cast](args = (%permute, f16), kwargs = {})
    %mma_K2:0_N:0_M:0 :  [num_users=0] = [mma](args = (%cast_K2:0_N:0_M:0, %read_K2:0_N:0_M:0, %dv_K2:0_N:0_M:0, None), kwargs = {})
    %mma_K2:0_N:1_M:0 :  [num_users=0] = [mma](args = (%cast_K2:0_N:0_M:0, %read_K2:0_N:1_M:0, %dv_K2:0_N:1_M:0, None), kwargs = {})
    %mma_1 :  [num_users=1] = [mma](args = (%cast, %read_2, %dv_prev, None), kwargs = {})
    %ds :  [num_users=2] = placeholder[target=ds]
    %read_K2:0_M:0 :  [num_users=1] = [read](args = (%ds, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_3 :  [num_users=1] = [read](args = (%ds, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_K2:0_M:0 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, IndexMapping(iters={$index0: 0, $index1: 1}, input_mapping={M: $index1, K1: $index0}), output_mapping={K1: $index0, M: $index1}, dynamic_val_mappings=(), (), None), kwargs = {})
    %read_4 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, IndexMapping(iters={$index0: 0, $index1: 1}, input_mapping={M: $index1, K1: $index0}), output_mapping={K1: $index0, M: $index1}, dynamic_val_mappings=(), (), None), kwargs = {})
    %mma_K2:0_M:0 :  [num_users=0] = [mma](args = (%read_K2:0_M:0, %read_K2:0_M:0, %dk_K2:0_M:0, None), kwargs = {})
    %mma_2 :  [num_users=1] = [mma](args = (%read_3, %read_4, %dk_prev, None), kwargs = {})
    return (mma_1, mma_2)
region_1 [root]:
graph():
    %q :  [num_users=1] = placeholder[target=q]
    %k :  [num_users=1] = placeholder[target=k]
    %do :  [num_users=1] = placeholder[target=do]
    %dk :  [num_users=2] = placeholder[target=dk]
    %dv :  [num_users=3] = placeholder[target=dv]
    %ds :  [num_users=1] = placeholder[target=ds]
    %register_K2:0_N:0_M:0 :  [num_users=0] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %register_K2:0_N:1_M:0 :  [num_users=0] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %register :  [num_users=1] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %register_K2:0_M:0 :  [num_users=0] = [register](args = ((K2, K1), f32, 0.0), kwargs = {})
    %register_1 :  [num_users=1] = [register](args = ((K2, K1), f32, 0.0), kwargs = {})
    %reduction : [Register[K2, N].of(f32), Register[K2, K1].of(f32)] [num_users=5] = [reduction](args = (M, [%register, %register_1], region_0, [%k, %q, %do, %ds]), kwargs = {})
    %getitem_K2:0_N:0_M:0 :  [num_users=1] = [getitem](args = (%reduction, 0), kwargs = {})
    %getitem_K2:0_N:1_M:0 :  [num_users=1] = [getitem](args = (%reduction, 0), kwargs = {})
    %getitem :  [num_users=1] = [getitem](args = (%reduction, 0), kwargs = {})
    %getitem_K2:0_M:0 :  [num_users=1] = [getitem](args = (%reduction, 1), kwargs = {})
    %getitem_1 :  [num_users=1] = [getitem](args = (%reduction, 1), kwargs = {})
    %write_K2:0_N:0_M:0 :  [num_users=0] = [write](args = (%getitem_K2:0_N:0_M:0, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write_K2:0_N:1_M:0 :  [num_users=0] = [write](args = (%getitem_K2:0_N:1_M:0, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write :  [num_users=0] = [write](args = (%getitem, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write_K2:0_M:0 :  [num_users=0] = [write](args = (%getitem_K2:0_M:0, %dk, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write_1 :  [num_users=0] = [write](args = (%getitem_1, %dk, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    return None

I say that because after `fixup_reduction_nodes` we get this:

After `fixup_reduction_nodes`
After fixup_reduction_nodes
region_0:
graph():
    %dv_K2:0_N:0_M:0 :  [num_users=1] = placeholder[target=dv_prev]
    %dv_K2:0_N:1_M:0 :  [num_users=1] = placeholder[target=dv_prev]
    %dk_K2:0_M:0 :  [num_users=1] = placeholder[target=dk_prev]
    %k :  [num_users=1] = placeholder[target=k]
    %read_K2:0_N:0_M:0_K1:0 :  [num_users=1] = [read](args = (%k, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %q :  [num_users=2] = placeholder[target=q]
    %read_K2:0_N:0_M:0_K1:0 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %register_K2:0_N:0_M:0_K1:0 :  [num_users=1] = [register](args = ((M, K2), f32, 0.0), kwargs = {})
    %mma_K2:0_N:0_M:0_K1:0 :  [num_users=1] = [mma](args = (%read_K2:0_N:0_M:0_K1:0, %read_K2:0_N:0_M:0_K1:0, %register_K2:0_N:0_M:0_K1:0, None), kwargs = {})
    %permute_K2:0_N:0_M:0 :  [num_users=1] = [permute](args = (%mma_K2:0_N:0_M:0_K1:0, [K2, M]), kwargs = {})
    %do :  [num_users=2] = placeholder[target=do]
    %read_K2:0_N:0_M:0 :  [num_users=1] = [read](args = (%do, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_K2:0_N:1_M:0 :  [num_users=1] = [read](args = (%do, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %cast_K2:0_N:0_M:0 :  [num_users=2] = [cast](args = (%permute_K2:0_N:0_M:0, f16), kwargs = {})
    %mma_K2:0_N:0_M:0 :  [num_users=1] = [mma](args = (%cast_K2:0_N:0_M:0, %read_K2:0_N:0_M:0, %dv_K2:0_N:0_M:0, None), kwargs = {})
    %mma_K2:0_N:1_M:0 :  [num_users=1] = [mma](args = (%cast_K2:0_N:0_M:0, %read_K2:0_N:1_M:0, %dv_K2:0_N:1_M:0, None), kwargs = {})
    %ds :  [num_users=1] = placeholder[target=ds]
    %read_K2:0_M:0 :  [num_users=1] = [read](args = (%ds, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_K2:0_M:0 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, IndexMapping(iters={$index0: 0, $index1: 1}, input_mapping={M: $index1, K1: $index0}), output_mapping={K1: $index0, M: $index1}, dynamic_val_mappings=(), (), None), kwargs = {})
    %mma_K2:0_M:0 :  [num_users=0] = [mma](args = (%read_K2:0_M:0, %read_K2:0_M:0, %dk_K2:0_M:0, None), kwargs = {})
    return [mma_K2:0_N:0_M:0, mma_K2:0_N:1_M:0]
region_1 [root]:
graph():
    %q :  [num_users=1] = placeholder[target=q]
    %k :  [num_users=1] = placeholder[target=k]
    %do :  [num_users=1] = placeholder[target=do]
    %dk :  [num_users=2] = placeholder[target=dk]
    %dv :  [num_users=3] = placeholder[target=dv]
    %ds :  [num_users=1] = placeholder[target=ds]
    %register_K2:0_N:0_M:0 :  [num_users=1] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %register_K2:0_N:1_M:0 :  [num_users=1] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %register :  [num_users=0] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %register_K2:0_M:0 :  [num_users=0] = [register](args = ((K2, K1), f32, 0.0), kwargs = {})
    %register_1 :  [num_users=0] = [register](args = ((K2, K1), f32, 0.0), kwargs = {})
    %reduction : [Register[K2, N].of(f32), Register[K2, K1].of(f32)] [num_users=5] = [reduction](args = (M, [%register_K2:0_N:0_M:0, %register_K2:0_N:1_M:0], region_0, [%k, %q, %do, %ds]), kwargs = {})
    %getitem_K2:0_N:0_M:0 :  [num_users=1] = [get_result](args = (%reduction, 0), kwargs = {})
    %getitem_K2:0_N:1_M:0 :  [num_users=1] = [get_result](args = (%reduction, 1), kwargs = {})
    %getitem :  [num_users=1] = [getitem](args = (%reduction, 0), kwargs = {})
    %getitem_K2:0_M:0 :  [num_users=1] = [getitem](args = (%reduction, 1), kwargs = {})
    %getitem_1 :  [num_users=1] = [getitem](args = (%reduction, 1), kwargs = {})
    %write_K2:0_N:0_M:0 :  [num_users=0] = [write](args = (%getitem_K2:0_N:0_M:0, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write_K2:0_N:1_M:0 :  [num_users=0] = [write](args = (%getitem_K2:0_N:1_M:0, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write :  [num_users=0] = [write](args = (%getitem, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write_K2:0_M:0 :  [num_users=0] = [write](args = (%getitem_K2:0_M:0, %dk, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    %write_1 :  [num_users=0] = [write](args = (%getitem_1, %dk, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    return None

The init arg corresponding to %register_1 (the K2xK1 dk accumulator) has been lost: the reduction's init args are now two N-expanded copies of the dv accumulator, and the dk mma in region_0 is left with no users. This then results in an error further along when emitting the MLIR.

FYI @harsh-nod


GMNGeoffrey commented Jan 13, 2025

I think the issue is in `compute_result_index` and has to do with the results getting expanded along different dimensions. That in turn seems to come from `dim_scaling` not including the N dimension, which, based on `get_dim_scaling`, looks to be because N isn't in the vector shapes for the reduction. Setting it explicitly in the hardware constraints seems to fix the issue, but it seems like that shouldn't be necessary (and neither should the degenerate distribution). For concreteness, a sketch of the workaround is below.
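# Sketch of the workaround, not a recommended fix. The 16 is my guess at a
# value matching the F32_16x16x16_F16 variant used in the repro:
tkw.HardwareConstraint(
    threads_per_wave=64,
    waves_per_block=(1, 1, 1),
    mma_type=mfma_variant,
    # Hypothetical explicit entry so N shows up in dim_scaling; other dims
    # may need entries too depending on the variant.
    vector_shapes={N: 16},
)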

GMNGeoffrey commented

Also, setting it explicitly in the vector shapes breaks other MMA variants and shape combinations.
