Fix shared memory OOM on sm86 GPUs. (#4797)

This commit is contained in:
Yi Pan
2025-03-27 01:41:53 +08:00
committed by GitHub
parent d89c0e4b7e
commit 45fdf1f7f3
2 changed files with 4 additions and 4 deletions

View File

@@ -341,8 +341,8 @@ def extend_attention_fwd(
else:
BLOCK_M, BLOCK_N = (32, 64)
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
# 8.9 has a much smaller shared memory size (100K) than 8.0 (160K)
if CUDA_CAPABILITY[1] == 9:
# sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
if Lq <= 128:
BLOCK_M, BLOCK_N = (64, 128)
elif Lq <= 256: