Support serving DeepSeek-R1-Channel-INT8 with 32 L40S. (#4418)

Wenbo Yang
2025-03-17 15:03:43 +08:00
committed by GitHub
parent 0f52fb55ec
commit 75b656488a
7 changed files with 489 additions and 11 deletions


@@ -279,6 +279,143 @@ void sm80_dispatch_shape(
   }
 }
 
+// Dispatch shape for sm89 (L40S, L20, RTX 4090), according to:
+// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
+template <typename ElementOutput, typename ArchTag, typename InstructionShape>
+void sm89_dispatch_shape(
+    torch::Tensor& out,
+    const torch::Tensor& mat_a,
+    const torch::Tensor& mat_b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b,
+    const c10::optional<torch::Tensor>& bias) {
+  int m = mat_a.size(0);
+  int n = mat_b.size(1);
+  if (m <= 16) {
+    if (n <= 8192) {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<16, 64, 128>,
+          cutlass::gemm::GemmShape<16, 64, 64>,
+          InstructionShape,
+          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<16, 128, 128>,
+          cutlass::gemm::GemmShape<16, 64, 64>,
+          InstructionShape,
+          4>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+  } else if (m <= 32) {
+    if (n <= 8192) {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<32, 64, 128>,
+          cutlass::gemm::GemmShape<16, 64, 64>,
+          InstructionShape,
+          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<32, 128, 128>,
+          cutlass::gemm::GemmShape<32, 64, 64>,
+          InstructionShape,
+          4>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+  } else if (m <= 64) {
+    if (n <= 8192) {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<64, 64, 128>,
+          cutlass::gemm::GemmShape<32, 64, 64>,
+          InstructionShape,
+          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<64, 128, 128>,
+          cutlass::gemm::GemmShape<64, 64, 64>,
+          InstructionShape,
+          3>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+  } else if (m <= 128) {
+    if (n <= 8192) {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<64, 128, 128>,
+          cutlass::gemm::GemmShape<32, 64, 64>,
+          InstructionShape,
+          3>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else if (n <= 16384) {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<128, 128, 64>,
+          cutlass::gemm::GemmShape<64, 64, 64>,
+          InstructionShape,
+          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<64, 64, 128>,
+          cutlass::gemm::GemmShape<32, 64, 64>,
+          InstructionShape,
+          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+  } else if (m <= 256) {
+    if (n <= 4096) {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<64, 128, 128>,
+          cutlass::gemm::GemmShape<64, 64, 64>,
+          InstructionShape,
+          3>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else if (n <= 8192) {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<128, 128, 64>,
+          cutlass::gemm::GemmShape<64, 64, 64>,
+          InstructionShape,
+          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else if (n <= 16384) {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<256, 128, 64>,
+          cutlass::gemm::GemmShape<64, 64, 64>,
+          InstructionShape,
+          3>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      cutlass_int8_scaled_mm<
+          ElementOutput,
+          ArchTag,
+          cutlass::gemm::GemmShape<128, 128, 64>,
+          cutlass::gemm::GemmShape<64, 64, 64>,
+          InstructionShape,
+          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+  } else {
+    cutlass_int8_scaled_mm<
+        ElementOutput,
+        ArchTag,
+        cutlass::gemm::GemmShape<32, 64, 128>,
+        cutlass::gemm::GemmShape<16, 64, 64>,
+        InstructionShape,
+        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
+  }
+}
+
 template <
     typename ElementOutput,
     typename TileShape,
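
For orientation: the new dispatcher only selects tile/warp shapes and pipeline stage counts per (m, n) bucket; callers still go through the same int8_scaled_mm entry point. A minimal usage sketch, assuming the kernel is exposed to Python as sgl_kernel.int8_scaled_mm with the (mat_a, mat_b, scales_a, scales_b, out_dtype, bias) signature exercised by the test file below (the real binding's layout requirements may differ):

import torch
from sgl_kernel import int8_scaled_mm  # assumed Python binding name

M, N, K = 32, 4096, 8192  # lands in the "m <= 32, n <= 8192" bucket above

a = torch.randint(-16, 16, (M, K), dtype=torch.int8, device="cuda")
# CUTLASS w8a8 GEMMs typically want the weight operand column-major;
# materializing (N, K) and transposing gives a (K, N) view with that layout.
b = torch.randint(-16, 16, (N, K), dtype=torch.int8, device="cuda").t()
scales_a = torch.rand((M, 1), dtype=torch.float32, device="cuda") * 1e-2  # per-token
scales_b = torch.rand((1, N), dtype=torch.float32, device="cuda") * 1e-2  # per-channel

out = int8_scaled_mm(a, b, scales_a, scales_b, torch.bfloat16, None)
print(out.shape, out.dtype)  # torch.Size([32, 4096]) torch.bfloat16

Since m is the number of tokens in the batch, the small-m buckets (16/32/64) are the ones a latency-bound decode workload actually hits, which is why the heuristic is so fine-grained there.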
@@ -566,12 +703,23 @@ torch::Tensor int8_scaled_mm(
     sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
         out, mat_a, mat_b, scales_a, scales_b, bias);
   } else if (sm_version >= 80 && sm_version < 90) {
-    if (out_dtype == torch::kBFloat16) {
-      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
-          out, mat_a, mat_b, scales_a, scales_b, bias);
-    } else {
-      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
-          out, mat_a, mat_b, scales_a, scales_b, bias);
-    }
+    // sm89 has a much smaller shared memory size (100K) than sm80 (160K)
+    if (sm_version == 89) {
+      if (out_dtype == torch::kBFloat16) {
+        sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
+            out, mat_a, mat_b, scales_a, scales_b, bias);
+      } else {
+        sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
+            out, mat_a, mat_b, scales_a, scales_b, bias);
+      }
+    } else {
+      if (out_dtype == torch::kBFloat16) {
+        sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
+            out, mat_a, mat_b, scales_a, scales_b, bias);
+      } else {
+        sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
+            out, mat_a, mat_b, scales_a, scales_b, bias);
+      }
+    }
   } else if (sm_version == 90) {
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
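
The stage counts in sm89_dispatch_shape are what keep these kernel instantiations inside the 100K shared-memory budget the comment mentions. A back-of-the-envelope check (a simplified model: each mainloop stage buffers one int8 A tile of tile_m x tile_k and one B tile of tile_k x tile_n; the real CUTLASS footprint also counts epilogue storage and padding):

def approx_smem_bytes(tile_m, tile_n, tile_k, stages, elem_bytes=1):
    # Rough mainloop shared-memory estimate for an int8 GEMM:
    # stages copies of an A tile plus a B tile, one byte per element.
    return stages * (tile_m * tile_k + tile_k * tile_n) * elem_bytes

# Largest sm89 config above: GemmShape<256, 128, 64> with 3 stages.
print(approx_smem_bytes(256, 128, 64, 3))  # 73728 B, inside the ~100K sm89 budget
# A 5-stage variant of the same tile, as an sm80 heuristic might pick:
print(approx_smem_bytes(256, 128, 64, 5))  # 122880 B, over the sm89 limit

This is why the sm89 path generally pairs large tiles with 3 stages and reserves 4-5 stages for the smaller tiles, whereas sm80's 160K budget tolerates both at once.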


@@ -37,7 +37,7 @@ class TestInt8Gemm(unittest.TestCase):
         print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
 
     def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096, 8192]
+        Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
         Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
         Ks = [512, 1024, 4096, 8192, 16384]
         bias_opts = [True, False]
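
The newly added Ms values 16, 32, and 64 land exactly on the small-m buckets of sm89_dispatch_shape, so each new tile configuration gets accuracy coverage. For reference, a condensed sketch of the kind of check such a test performs, assuming the same sgl_kernel.int8_scaled_mm binding as above and illustrative tolerances:

import torch
from sgl_kernel import int8_scaled_mm  # assumed binding, as above

def check_one(M, N, K, out_dtype=torch.bfloat16):
    a = torch.randint(-16, 16, (M, K), dtype=torch.int8, device="cuda")
    b = torch.randint(-16, 16, (N, K), dtype=torch.int8, device="cuda").t()
    scales_a = torch.rand((M, 1), dtype=torch.float32, device="cuda") * 1e-2
    scales_b = torch.rand((1, N), dtype=torch.float32, device="cuda") * 1e-2

    out = int8_scaled_mm(a, b, scales_a, scales_b, out_dtype, None)

    # Float reference: dequantize both operands, then an ordinary matmul.
    ref = ((a.float() * scales_a) @ (b.float() * scales_b)).to(out_dtype)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)

for M in [1, 16, 32, 64, 128]:  # covers every small-m bucket in sm89_dispatch_shape
    check_one(M, N=4096, K=8192)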