bwd_prepare_wy_repr_kernel Scratch Space Exceeds HW Supported Limit #3449
Comments
Hi @uniartisan, thanks for reporting the issue on the XPU backend. I tried to reproduce it on the A770 platform but got some different errors. After fixing those errors, I can run the test script successfully without any issue. Can you tell us which Intel GPU you get this error on?

import torch
for i in range(torch.xpu.device_count()):
    print(torch.xpu.get_device_capability(i))

The errors I hit appear not related to the original one reported.
import math
import torch
from torch import nn
import torch.nn.functional as F
from fla.layers import RWKV7Attention  # type: ignore
from fla.utils import device


class TMix(nn.Module):
    def __init__(self, dim, block_id, n_blocks):
        super().__init__()
        self.rwkv7 = RWKV7Attention(
            "chunk",
            dim,
            layer_idx=block_id
        )

    def forward(self, x, v_first):
        x_attn, _, v_first = self.rwkv7(x, v_first=v_first)
        return x_attn, v_first


class CMix(nn.Module):
    def __init__(self, dim, hidden_dim, block_id, n_blocks):
        super().__init__()
        self.value = nn.Linear(dim, dim)

    def forward(self, x):
        return self.value(x)


class RWKV7Block(nn.Module):
    def __init__(self, dim, block_id, n_blocks):
        super().__init__()
        self.attn = TMix(dim, block_id, n_blocks)
        self.mlp = CMix(dim, dim * 4, block_id, n_blocks)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, v_first):
        x_attn, v_first = self.attn(self.norm1(x), v_first=v_first)
        x = x + x_attn
        x = x + self.mlp(self.norm2(x))
        return x, v_first


class RWKV7(nn.Module):
    def __init__(self, vocab_size, dim, n_blocks: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([
            RWKV7Block(dim, i, n_blocks)
            for i in range(n_blocks)
        ])
        self.lmhead = nn.Linear(dim, vocab_size)
        self.norm_in = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):
        x = self.norm_in(self.wte(x))
        # v_first = None
        v_first = x.clone()
        for block in self.blocks:
            x, v_first = block(x, v_first)
        return self.lmhead(self.norm_out(x))


bs, seq_len = 1, 4096
data = torch.randint(0, 1024, (bs, seq_len + 1)).to(device)
x = data[:, :-1]
target = data[:, 1:]

model = RWKV7(1024, 4096, 1).to(device)
y = model(x)
loss = F.cross_entropy(y.view(-1, 1024), target.view(-1))
loss.backward()
Thank you for your response. It seems that you installed an old version. You can upgrade to the latest rwkv-fla to fix this: run pip install rwkv-fla==0.7.202502171252. In this version, I use Triton to get the multiprocessor count and fix the CUDA-specific code paths for other platforms.

>>> torch.xpu.get_device_capability(0)
{'architecture': 13115590656, 'driver_version': '1.3.29735+27', 'gpu_eu_count': 512, 'gpu_subslice_count': 32, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': False, 'has_subgroup_2d_block_io': False, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 128, 'max_work_group_size': 1024, 'name': 'Intel(R) Arc(TM) A770 Graphics', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [8, 16, 32], 'total_memory': 16225243136, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.55.8'}
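For illustration, a torch-only sketch of the same idea could look like this (the function name and the XPU fallback field are placeholders chosen for this example, taken from the capability dict above, not the exact code in the package):

import torch

def get_multiprocessor_count(index: int = 0) -> int:
    # CUDA exposes the streaming multiprocessor count directly.
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(index).multi_processor_count
    # XPU has no multi_processor_count; gpu_subslice_count from
    # get_device_capability (see the dict above) is a rough stand-in.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.get_device_capability(index)["gpu_subslice_count"]
    return 1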
I can reproduce your issue after updating the package. On this platform we only use FMA-based inner products to do the computation. There are two quick ways to reduce the register space used by your kernel:
The suggestion is to try this new autotune configuration for your kernel:

@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'grf_mode': 'large'}, num_warps=num_warps)
        for num_warps in [8, 16, 32]
    ],
    key=['BT', 'BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def bwd_prepare_wy_repr_kernel(

It works in my env. Please have a try.
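If you want to confirm which configuration the autotuner ends up selecting, one option (assuming your Triton build honors the TRITON_PRINT_AUTOTUNING environment variable) is to enable its logging before the kernel is first launched:

import os
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  # prints the winning config once tuning finishes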
Thank you for your reproduction :)
Here is the general introduction to the GRF mode: https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2023-2/small-register-mode-vs-large-register-mode.html In short, it trades the maximum number of threads that can be scheduled for more register space per thread. In the Triton XPU backend, the preference is to eliminate register spilling for a kernel: the compiler pipeline will choose the large GRF mode if the register spill size exceeds 1 KB in small GRF mode. Passing the GRF mode through the keyword argument is an explicit way for the user to override that default behavior.
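As an illustration of that trade-off, the autotuner could also be allowed to choose the mode instead of forcing it, along these lines (a sketch only; I am assuming 'small' is accepted in the Config kwargs the same way 'large' is in the snippet above):

configs = [
    triton.Config({'grf_mode': mode}, num_warps=num_warps)
    for mode in ('small', 'large')  # 'small' is assumed; 'large' is confirmed above
    for num_warps in [8, 16, 32]
]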
Thank you |
Describe the bug
Dear Intel Triton Team:
I'm encountering an issue with the bwd_prepare_wy_repr_kernel in my Python code that uses Triton for the kernel implementation. The error message indicates that the total scratch space exceeds the hardware-supported limit for the kernel.
Error Details
Code Context
The following is the code snippet for the bwd_prepare_wy_repr_kernel and the related chunk_dplr_bwd_wy function.
Test Method
I installed rwkv-fla using pip install rwkv-fla and then ran the test script quoted earlier in this thread.
Expected Behavior
The code should run without any issues (except for incorrect tiling that causes an overflow of global shared memory).
Environment details
Package Version
importlib_metadata 8.0.0
inflect 7.3.1
iniconfig 2.0.0
intel-cmplr-lib-rt 2025.0.2
intel-cmplr-lib-ur 2025.0.2
intel-cmplr-lic-rt 2025.0.2
intel-openmp 2025.0.2
intel-pti 0.10.0
intel-sycl-rt 2025.0.2
pytorch-triton-xpu 3.2.0
torch 2.6.0+xpu
torchaudio 2.6.0+xpu
torchvision 0.21.0+xpu