Support serving DeepSeek-R1-Channel-INT8 with 32 L40S. (#4418)
@@ -218,6 +218,33 @@ python3 -m sglang.bench_serving --dataset-path /path/to/ShareGPT_V3_unfiltered_c

> **Note: using `--parallel 200` can accelerate accuracy benchmarking**.

### Example: Serving with 32 L40S with int8 Quantization

Run with the per-channel quantized model:

- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)

Assuming the master node IP is `MASTER_IP`, the checkpoint path is `/path/to/DeepSeek-R1-Channel-INT8`, and the distributed init port is 5000, the server can be launched across the four nodes with the following commands:

```bash
# master node (rank 0)
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
    --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 0 --trust-remote-code \
    --enable-torch-compile --torch-compile-max-bs 32
# remaining nodes (ranks 1-3)
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
    --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 1 --trust-remote-code \
    --enable-torch-compile --torch-compile-max-bs 32
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
    --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 2 --trust-remote-code \
    --enable-torch-compile --torch-compile-max-bs 32
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
    --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 3 --trust-remote-code \
    --enable-torch-compile --torch-compile-max-bs 32
```

The benchmarking method is the same as described in the previous [16 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization) example.
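Once all four ranks are up, a quick smoke test against the master node can look like the sketch below. This is illustrative only and not part of the commit: it assumes the default HTTP port 30000 (the 5000 above is only the distributed init port) and sglang's OpenAI-compatible `/v1/chat/completions` endpoint.

```python
import requests

# Assumed: default HTTP port 30000 and the OpenAI-compatible chat endpoint.
resp = requests.post(
    "http://MASTER_IP:30000/v1/chat/completions",
    json={
        "model": "meituan/DeepSeek-R1-Channel-INT8",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
    },
    timeout=600,
)
print(resp.json()["choices"][0]["message"]["content"])
```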
### Example: Serving on any cloud or Kubernetes with SkyPilot

SkyPilot helps find the cheapest available GPUs across clouds and existing Kubernetes clusters and launches distributed serving with a single command. See details [here](https://github.com/skypilot-org/skypilot/tree/master/llm/deepseek-r1).
@@ -18,6 +18,7 @@ SGLang is recognized as one of the top engines for [DeepSeek model inference](ht

| **Quantized weights (AWQ)**  | 8 x H100/800/20 |
|                              | 8 x A100/A800   |
| **Quantized weights (int8)** | 16 x A100/800   |
|                              | 32 x L40S       |

<style>
.md-typeset__table {

@@ -56,6 +57,7 @@ Detailed commands for reference:

- [4 x 8 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-four-a1008-nodes)
- [8 x A100 (AWQ)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-8-a100a800-with-awq-quantization)
- [16 x A100 (int8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization)
- [32 x L40S (int8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-32-l40s-with-int8-quantization)

### Download Weights
@@ -341,12 +341,21 @@ def extend_attention_fwd(
         else:
             BLOCK_M, BLOCK_N = (32, 64)
     elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        if Lq <= 128:
-            BLOCK_M, BLOCK_N = (128, 128)
-        elif Lq <= 256:
-            BLOCK_M, BLOCK_N = (64, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
+        # 8.9 has a much smaller shared memory size (100K) than 8.0 (160K)
+        if CUDA_CAPABILITY[1] == 9:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (64, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 32)
+        else:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
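Condensed outside the kernel, the new selection reads as the standalone sketch below (illustrative restatement only, using `torch.cuda.get_device_capability()`; the 100K/160K figures are the shared-memory sizes noted in the comment above):

```python
import torch

def pick_extend_attention_blocks(Lq: int) -> tuple:
    # Mirrors the dispatch above: sm8.9 (e.g. L40S) has ~100 KB of shared
    # memory per SM vs ~160 KB on sm8.0 (A100), so it gets smaller tiles.
    major, minor = torch.cuda.get_device_capability()
    if major >= 8:
        if minor == 9:
            if Lq <= 128:
                return (64, 128)
            return (64, 64) if Lq <= 256 else (32, 32)
        if Lq <= 128:
            return (128, 128)
        return (64, 64) if Lq <= 256 else (32, 64)
    return (64, 64) if Lq <= 128 else (32, 32)

# e.g. on an L40S: pick_extend_attention_blocks(128) -> (64, 128)
```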
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 2
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 2
    },
    "512": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "1024": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "1536": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 2
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 2
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 2
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2
    },
    "512": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "1024": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "1536": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 2
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    }
}
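These two files are tuned Triton fused-MoE configs: each key is a batch size M and each value a set of kernel launch parameters. At runtime the config whose key is nearest to the actual M is typically selected (following the vLLM convention that sglang's fused MoE inherits). A minimal sketch of that lookup, with an illustrative file name:

```python
import json

def pick_moe_config(path: str, m: int) -> dict:
    # Load a tuned config file like the ones above and return the entry
    # whose batch-size key is nearest to the runtime M.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda k: abs(k - m))]

# pick_moe_config("fused_moe_int8_config.json", m=50)  # -> the "48" entry
```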
@@ -279,6 +279,143 @@ void sm80_dispatch_shape(
  }
}

// Dispatch shape for sm89 (L40S, L20, RTX 4090), according to:
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm89_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 16) {
    if (n <= 8192) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 128, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          4>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 32) {
    if (n <= 8192) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 128, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          4>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    if (n <= 8192) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 128, 128>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          3>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    if (n <= 8192) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 128, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          3>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 256) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 128, 128>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          3>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else if (n <= 8192) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<256, 128, 64>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          3>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<32, 64, 128>,
        cutlass::gemm::GemmShape<16, 64, 64>,
        InstructionShape,
        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}

template <
    typename ElementOutput,
    typename TileShape,
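As a sanity check on the tile choices above: each pipeline stage buffers one int8 A tile (BLOCK_M x BLOCK_K) and one int8 B tile (BLOCK_N x BLOCK_K) in shared memory, so a rough lower bound on usage is stages * (BLOCK_M + BLOCK_N) * BLOCK_K bytes. The sketch below (an estimate only; it ignores epilogue and scale buffers) shows that every configuration used by `sm89_dispatch_shape` stays at or under 80 KB, within the ~100 KB sm89 budget:

```python
# (BLOCK_M, BLOCK_N, BLOCK_K, stages) combinations from sm89_dispatch_shape
CONFIGS = [
    (16, 64, 128, 5), (16, 128, 128, 4),
    (32, 64, 128, 5), (32, 128, 128, 4),
    (64, 64, 128, 5), (64, 128, 128, 3),
    (128, 128, 64, 5), (256, 128, 64, 3),
]

SM89_SMEM = 100 * 1024  # ~100 KB per SM on sm89 (vs ~160 KB on sm80)

for bm, bn, bk, stages in CONFIGS:
    smem = stages * (bm * bk + bn * bk)  # int8 operands: 1 byte per element
    status = "ok" if smem <= SM89_SMEM else "OVER BUDGET"
    print(f"<{bm},{bn},{bk}> x{stages} stages ~ {smem // 1024} KB ({status})")
```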
@@ -566,12 +703,23 @@ torch::Tensor int8_scaled_mm(
     sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
         out, mat_a, mat_b, scales_a, scales_b, bias);
   } else if (sm_version >= 80 && sm_version < 90) {
-    if (out_dtype == torch::kBFloat16) {
-      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
-          out, mat_a, mat_b, scales_a, scales_b, bias);
+    // sm89 has a much smaller shared memory size (100K) than sm80 (160K)
+    if (sm_version == 89) {
+      if (out_dtype == torch::kBFloat16) {
+        sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
+            out, mat_a, mat_b, scales_a, scales_b, bias);
+      } else {
+        sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
+            out, mat_a, mat_b, scales_a, scales_b, bias);
+      }
     } else {
-      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
-          out, mat_a, mat_b, scales_a, scales_b, bias);
+      if (out_dtype == torch::kBFloat16) {
+        sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
+            out, mat_a, mat_b, scales_a, scales_b, bias);
+      } else {
+        sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
+            out, mat_a, mat_b, scales_a, scales_b, bias);
+      }
     }
   } else if (sm_version == 90) {
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
@@ -37,7 +37,7 @@ class TestInt8Gemm(unittest.TestCase):
         print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

     def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096, 8192]
+        Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
         Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
         Ks = [512, 1024, 4096, 8192, 16384]
         bias_opts = [True, False]
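For context, what the w8a8 per-channel path computes: per-token int8 activations multiplied by per-output-channel int8 weights with integer accumulation, then rescaled by the product of the two scale vectors. A minimal PyTorch sketch of that reference computation (illustrative only, not the repo's test code):

```python
import torch

def quant_rowwise_int8(x: torch.Tensor):
    # Symmetric per-row int8 quantization: per-token for activations,
    # per-output-channel for weights.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

a = torch.randn(128, 4096)        # activations, M x K
w = torch.randn(1024, 4096)       # weights, N x K (one row per output channel)
a_q, s_a = quant_rowwise_int8(a)  # s_a: M x 1
w_q, s_w = quant_rowwise_int8(w)  # s_w: N x 1

# Integer accumulation (int64 here for the CPU reference), then rescale:
# y ~= (a_q @ w_q^T) * s_a * s_w^T
y = (a_q.long() @ w_q.long().T).float() * s_a * s_w.T
ref = a @ w.T
err = ((y - ref).abs().max() / ref.abs().max()).item()
print(f"max abs err / max abs ref: {err:.3%}")  # typically well under 2%
```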