Fix shared memory OOM on sm86 GPUs. (#4797)
This commit is contained in:
@@ -341,8 +341,8 @@ def extend_attention_fwd(
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = (32, 64)
|
||||
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||
- # 8.9 has a much smaller shared memory size (100K) than 8.0 (160K)
|
||||
- if CUDA_CAPABILITY[1] == 9:
|
||||
+ # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||
+ if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
|
||||
if Lq <= 128:
|
||||
BLOCK_M, BLOCK_N = (64, 128)
|
||||
elif Lq <= 256:
|
||||
|
||||
Reference in New Issue
Block a user