
bwd_prepare_wy_repr_kernel Scratch Space Exceeds HW Supported Limit #3449

Closed
uniartisan opened this issue Feb 17, 2025 · 6 comments
Labels: bug, community

@uniartisan
Describe the bug

Dear Intel Triton Team:

I'm encountering an issue with the bwd_prepare_wy_repr_kernel, a Triton kernel used in my Python code. The error message indicates that the kernel's total scratch space exceeds the hardware-supported limit.

Error Details

error: total scratch space exceeds HW supported limit for kernel bwd_prepare_wy_repr_kernel: 594720 bytes (max permitted PTSS 262144 bytes)
error: backend compiler failed build.

Code Context

The following is the code snippet for the bwd_prepare_wy_repr_kernel and the related chunk_dplr_bwd_wy function:

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl
from fla.utils import device_capacity


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT', 'BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def bwd_prepare_wy_repr_kernel(
    A_ab_inv,
    A_ak,
    ag,
    v,
    dw,
    du,
    dv,
    dag,
    dAak,
    dAab,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
        p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
        p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT,  (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
        p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
        p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
    b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
    b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
    b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
    b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
    b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
        b_dv = tl.dot(b_A_tmp_t, b_du)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0)
    b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
    b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0)
    tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
    b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)

    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
        b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
        tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))

    # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
    # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
    # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
    # denote A = I - lower(A_ab), B = A^-1
    # in the backward pass.
    # dL/dA = -(B)^T @ (dL/dB) @ B^T
    # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
    b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
    b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
    b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
    b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
    tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))


def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
    if head_first:
        B, H, T, K, V = *dw.shape, du.shape[-1]
    else:
        B, T, H, K, V = *dw.shape, du.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    if offsets is None:
        NT = triton.cdiv(T, BT)
    else:
        if indices is None:
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
        NT = len(indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64) if device_capacity else min(triton.next_power_of_2(V), 32)

    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    bwd_prepare_wy_repr_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dA_ab, dA_ak, dv, dag

Test Method

I installed rwkv-fla using pip install rwkv-fla and then ran the following script:

import math
import torch
from torch import nn
import torch.nn.functional as F
from fla.layers import RWKV7Attention # type: ignore
from fla.utils import device

class TMix(nn.Module):
    def __init__(self, dim, block_id, n_blocks):
        super().__init__()
        self.rwkv7 = RWKV7Attention(
            "chunk",
            dim,
            layer_idx=block_id
        )

    def forward(self, x, v_first):
        x_attn, _, past_key_values, v_first = self.rwkv7(x, v_first=v_first)
        return x_attn, v_first

class CMix(nn.Module):
    def __init__(self, dim, hidden_dim, block_id, n_blocks):
        super().__init__()
        self.value = nn.Linear(dim, dim)

    def forward(self, x):
        return self.value(x)

class RWKV7Block(nn.Module):
    def __init__(self, dim, block_id, n_blocks):
        super().__init__()
        self.attn = TMix(dim, block_id, n_blocks)
        self.mlp = CMix(dim, dim * 4, block_id, n_blocks)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, v_first):
        x_attn, v_first = self.attn(self.norm1(x), v_first=v_first)
        x = x + x_attn
        x = x + self.mlp(self.norm2(x))
        return x, v_first

class RWKV7(nn.Module):
    def __init__(self, vocab_size, dim, n_blocks: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([
            RWKV7Block(dim, i, n_blocks)
            for i in range(n_blocks)
        ])
        self.lmhead = nn.Linear(dim, vocab_size)
        self.norm_in = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):
        x = self.norm_in(self.wte(x))
        v_first = None
        for block in self.blocks:
            x, v_first = block(x, v_first)
        return self.lmhead(self.norm_out(x))

bs, seq_len = 1, 4096
data = torch.randint(0, 1024, (bs, seq_len + 1)).to(device)
x = data[:, :-1]
target = data[:, 1:]

model = RWKV7(1024, 4096, 1).to(device)
y = model(x)

loss = F.cross_entropy(y.view(-1, 1024), target.view(-1))
loss.backward()

Expected Behavior

The code should run without any issues (except when incorrect tiling causes an overflow of the shared memory).

Environment details

Package Version


importlib_metadata 8.0.0
inflect 7.3.1
iniconfig 2.0.0
intel-cmplr-lib-rt 2025.0.2
intel-cmplr-lib-ur 2025.0.2
intel-cmplr-lic-rt 2025.0.2
intel-openmp 2025.0.2
intel-pti 0.10.0
intel-sycl-rt 2025.0.2
pytorch-triton-xpu 3.2.0
torch 2.6.0+xpu
torchaudio 2.6.0+xpu
torchvision 0.21.0+xpu

@chengjunlu
Contributor

Hi @uniartisan ,

Thanks for reporting the issue on XPU backend.

I tried to reproduce the issue on the A770 platform but got a different error.

After fixing those errors, I can run the test script successfully.

Could you share which Intel GPU you are seeing this error on? You can use the following snippet to dump the GPU information of your environment:

import torch

for i in range(torch.xpu.device_count()):
    print(torch.xpu.get_device_capability(i))

Two side issues not related to the original one reported:

  1. I am using rwkv-fla 0.7.202501211312. There is an incompatibility with the original test script, so I modified it to run the test. Here is the modified script for reference:
import math
import torch
from torch import nn
import torch.nn.functional as F
from fla.layers import RWKV7Attention # type: ignore
from fla.utils import device

class TMix(nn.Module):
    def __init__(self, dim, block_id, n_blocks):
        super().__init__()
        self.rwkv7 = RWKV7Attention(
            "chunk",
            dim,
            layer_idx=block_id
        )

    def forward(self, x, v_first):
        x_attn, _, v_first = self.rwkv7(x, v_first=v_first)
        return x_attn, v_first

class CMix(nn.Module):
    def __init__(self, dim, hidden_dim, block_id, n_blocks):
        super().__init__()
        self.value = nn.Linear(dim, dim)

    def forward(self, x):
        return self.value(x)

class RWKV7Block(nn.Module):
    def __init__(self, dim, block_id, n_blocks):
        super().__init__()
        self.attn = TMix(dim, block_id, n_blocks)
        self.mlp = CMix(dim, dim * 4, block_id, n_blocks)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, v_first):
        x_attn, v_first = self.attn(self.norm1(x), v_first=v_first)
        x = x + x_attn
        x = x + self.mlp(self.norm2(x))
        return x, v_first

class RWKV7(nn.Module):
    def __init__(self, vocab_size, dim, n_blocks: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([
            RWKV7Block(dim, i, n_blocks)
            for i in range(n_blocks)
        ])
        self.lmhead = nn.Linear(dim, vocab_size)
        self.norm_in = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):
        x = self.norm_in(self.wte(x))
        # v_first = None
        v_first = x.clone()
        for block in self.blocks:
            x, v_first = block(x, v_first)
        return self.lmhead(self.norm_out(x))

bs, seq_len = 1, 4096
data = torch.randint(0, 1024, (bs, seq_len + 1)).to(device)
x = data[:, :-1]
target = data[:, 1:]

model = RWKV7(1024, 4096, 1).to(device)
y = model(x)

loss = F.cross_entropy(y.view(-1, 1024), target.view(-1))
loss.backward()
  2. There is some code in rwkv-fla hardcoded to run only on CUDA devices. I changed it for Intel GPU locally in my test (a rough backend-agnostic sketch of these substitutions follows this list):
    device="CUDA" to device=x.device
    with torch.cuda.device(x.device.index): to with torch.xpu.device(x.device.index):
    torch.cuda.get_device_properties(x.device).multi_processor_count to torch.xpu.get_device_properties(x.device).gpu_subslice_count
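
A rough backend-agnostic version of those substitutions could look like the following. This is only a sketch, not the actual rwkv-fla patch, and the helper names are illustrative:

import torch

def device_ctx(x: torch.Tensor):
    # Use the tensor's own device instead of a hardcoded CUDA device.
    if x.device.type == "xpu":
        return torch.xpu.device(x.device.index)
    return torch.cuda.device(x.device.index)

def processor_count(x: torch.Tensor) -> int:
    # CUDA exposes multi_processor_count; XPU exposes gpu_subslice_count.
    if x.device.type == "xpu":
        return torch.xpu.get_device_properties(x.device).gpu_subslice_count
    return torch.cuda.get_device_properties(x.device).multi_processor_count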

@uniartisan
Author

uniartisan commented Feb 19, 2025


Thank you for your response. It seems that you installed an old version. You can upgrade to the latest rwkv-fla to fix this: run pip install rwkv-fla==0.7.202502171252. In this version, I use Triton to get the multiprocessor count and have fixed all the CUDA-specific code paths. Here is the GPU information you asked for:

>>> torch.xpu.get_device_capability(0)
{'architecture': 13115590656, 'driver_version': '1.3.29735+27', 'gpu_eu_count': 512, 'gpu_subslice_count': 32, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': False, 'has_subgroup_2d_block_io': False, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 128, 'max_work_group_size': 1024, 'name': 'Intel(R) Arc(TM) A770 Graphics', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [8, 16, 32], 'total_memory': 16225243136, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.55.8'}
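
For reference, the Triton-based multiprocessor query mentioned above would look roughly like this. It is only a sketch rather than the actual rwkv-fla code; the CUDA driver reports the count under the key 'multiprocessor_count', and whether the XPU driver uses the same key is an assumption, so inspect the returned dict on your device:

import triton

def multiprocessor_count(device_index: int = 0) -> int:
    # The active Triton driver reports device properties as a dict.
    # 'multiprocessor_count' is the CUDA backend's key; treating it as
    # present on the XPU backend is an assumption of this sketch.
    props = triton.runtime.driver.active.utils.get_device_properties(device_index)
    return props.get('multiprocessor_count', 0)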

@chengjunlu
Contributor

chengjunlu commented Feb 20, 2025

I can reproduce your issue after updating rwkv-fla to the later version.

On the A770 we only use FMA inner products to compute tt.dot. Triton unrolls the inner-product loop aggressively (there is no loop at all for the FMA inner product), so a single tt.dot tile requires a lot of register space per warp, and the register spilling size ends up exceeding the HW limit.

There are two quick ways to reduce the register pressure for your kernel:

  1. Reduce the tile size per warp by increasing num_warps. Intel GPUs support up to 32 physical warps per thread block.
  2. Increase the GRF size per thread to reduce spilling. Intel GPUs can double the GRF size per thread at the cost of halving the number of threads that execute in parallel physically. You can use grf_mode to enable this explicitly for your Triton kernel.

The suggestion is to try this new autotune configuration for your kernel:

@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'grf_mode': 'large'}, num_warps=num_warps)
        for num_warps in [8, 16, 32]
    ],
    key=['BT', 'BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def bwd_prepare_wy_repr_kernel(

It works in my environment. Please give it a try.

@uniartisan
Author


Thank you for the reproduction :)
I'd like to ask whether grf_mode has any negative impacts. Should this configuration be enabled manually? In a previous Triton (Intel backend) version, I saw from the logs that it would automatically attempt to use the large GRF mode.

@chengjunlu
Contributor

chengjunlu commented Feb 21, 2025

Here is a general introduction to the GRF modes: https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2023-2/small-register-mode-vs-large-register-mode.html

In short, it trades the maximum number of threads that can be scheduled for more register space per thread.

In the Triton XPU backend, the preference is to eliminate register spilling for a kernel. The compiler pipeline will choose the large GRF mode if the register spilling size exceeds 1 KB in small GRF mode. Passing the GRF mode through the keyword argument is an explicit way for the user to override this default behavior.
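
If you prefer to let the autotuner weigh that trade-off itself rather than forcing one mode, you could include both modes in the search space, roughly like this (a sketch only; the comments above only demonstrate 'large', so treating 'small' as an accepted grf_mode value is an assumption):

configs = [
    triton.Config({'grf_mode': mode}, num_warps=num_warps)
    for mode in ('small', 'large')
    for num_warps in (8, 16, 32)
]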

@uniartisan
Author

Thank you
