Fix shared memory OOM on sm86 GPUs. (#4797)
This commit is contained in:
@@ -703,8 +703,8 @@ torch::Tensor int8_scaled_mm(
|
||||
sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (sm_version >= 80 && sm_version < 90) {
|
||||
// sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||
if (sm_version == 89) {
|
||||
// sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||
if (sm_version == 86 || sm_version == 89) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
|
||||
Reference in New Issue
Block a user