Skip to content

Commit 3fec0ed

Browse files
committed
Disable splitkv for most cases (only decoding) due to excessive RAM usage for long sequence.
1 parent 645e6f3 commit 3fec0ed

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
178178
// In any case we don't expect seqlen_q to be larger than 64 for inference.
179179
const int num_m_blocks = (R + 64 - 1) / 64;
180180
const ccv_nnc_cuda_device_prop_t props = ccv_nnc_gpu_device_props();
181-
params.num_splits = num_splits_heuristic(batch_size * Hq * num_m_blocks, props.multi_processor_count, num_n_blocks, 128);
181+
// Only enable splitkv if R is 1.
182+
params.num_splits = R == 1 ? num_splits_heuristic(batch_size * Hq * num_m_blocks, props.multi_processor_count * 2, num_n_blocks, 128) : 1;
182183
if (saved_softmax_lse)
183184
params.softmax_lse_ptr = saved_softmax_lse->data.u8;
184185
if (params.num_splits > 1)

0 commit comments

Comments
 (0)