Fix shared memory OOM on sm86 GPUs. (#4797)
This commit is contained in:
@@ -341,8 +341,8 @@ def extend_attention_fwd(
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = (32, 64)
|
||||
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||
# 8.9 has a much smaller shared memory size (100K) than 8.0 (160K)
|
||||
if CUDA_CAPABILITY[1] == 9:
|
||||
# sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||
if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
|
||||
if Lq <= 128:
|
||||
BLOCK_M, BLOCK_N = (64, 128)
|
||||
elif Lq <= 256:
|
||||
|
||||
@@ -703,8 +703,8 @@ torch::Tensor int8_scaled_mm(
|
||||
sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (sm_version >= 80 && sm_version < 90) {
|
||||
// sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||
if (sm_version == 89) {
|
||||
// sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||
if (sm_version == 86 || sm_version == 89) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
|
||||
Reference in New Issue
Block a user