Support serving DeepSeek-R1-Channel-INT8 with 32 L40S. (#4418)
This commit is contained in:
@@ -279,6 +279,143 @@ void sm80_dispatch_shape(
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch shape for sm89 (L40S, L20, RTX 4090), according to:
|
||||
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
|
||||
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
|
||||
void sm89_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
int m = mat_a.size(0);
|
||||
int n = mat_b.size(1);
|
||||
if (m <= 16) {
|
||||
if (n <= 8192) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<16, 128, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
InstructionShape,
|
||||
4>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 32) {
|
||||
if (n <= 8192) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<32, 128, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
InstructionShape,
|
||||
4>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 64) {
|
||||
if (n <= 8192) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
3>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 128) {
|
||||
if (n <= 8192) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
InstructionShape,
|
||||
3>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (n <= 16384) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 256) {
|
||||
if (n <= 4096) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
3>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (n <= 8192) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (n <= 16384) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
3>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename ElementOutput,
|
||||
typename TileShape,
|
||||
@@ -566,12 +703,23 @@ torch::Tensor int8_scaled_mm(
|
||||
sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (sm_version >= 80 && sm_version < 90) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
// sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||
if (sm_version == 89) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else {
|
||||
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
} else if (sm_version == 90) {
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
|
||||
@@ -37,7 +37,7 @@ class TestInt8Gemm(unittest.TestCase):
|
||||
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
|
||||
|
||||
def test_accuracy(self):
|
||||
Ms = [1, 128, 512, 1024, 4096, 8192]
|
||||
Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
|
||||
Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
|
||||
Ks = [512, 1024, 4096, 8192, 16384]
|
||||
bias_opts = [True, False]
|
||||
|
||||
Reference in New Issue
Block a user