Support serving DeepSeek-R1-Channel-INT8 with 32 L40S. (#4418)
This commit is contained in:
@@ -218,6 +218,33 @@ python3 -m sglang.bench_serving --dataset-path /path/to/ShareGPT_V3_unfiltered_c
|
|||||||
|
|
||||||
> **Note: using `--parallel 200` can accelerate accuracy benchmarking**.
|
> **Note: using `--parallel 200` can accelerate accuracy benchmarking**.
|
||||||
|
|
||||||
|
### Example: Serving with 32 L40S with int8 Quantization
|
||||||
|
|
||||||
|
Running with per-channel quantization model:
|
||||||
|
|
||||||
|
- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)
|
||||||
|
|
||||||
|
Assuming that master node IP is `MASTER_IP`, checkpoint path is `/path/to/DeepSeek-R1-Channel-INT8` and port=5000, we can have following commands to launch the server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#master
|
||||||
|
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||||
|
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 0 --trust-remote \
|
||||||
|
--enable-torch-compile --torch-compile-max-bs 32
|
||||||
|
#cluster
|
||||||
|
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||||
|
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 1 --trust-remote \
|
||||||
|
--enable-torch-compile --torch-compile-max-bs 32
|
||||||
|
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||||
|
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 2 --trust-remote \
|
||||||
|
--enable-torch-compile --torch-compile-max-bs 32
|
||||||
|
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||||
|
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 3 --trust-remote \
|
||||||
|
--enable-torch-compile --torch-compile-max-bs 32
|
||||||
|
```
|
||||||
|
|
||||||
|
The benchmarking method is the same as describted in the previous [16 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization) example.
|
||||||
|
|
||||||
### Example: Serving on any cloud or Kubernetes with SkyPilot
|
### Example: Serving on any cloud or Kubernetes with SkyPilot
|
||||||
|
|
||||||
SkyPilot helps find cheapest available GPUs across any cloud or existing Kubernetes clusters and launch distributed serving with a single command. See details [here](https://github.com/skypilot-org/skypilot/tree/master/llm/deepseek-r1).
|
SkyPilot helps find cheapest available GPUs across any cloud or existing Kubernetes clusters and launch distributed serving with a single command. See details [here](https://github.com/skypilot-org/skypilot/tree/master/llm/deepseek-r1).
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ SGLang is recognized as one of the top engines for [DeepSeek model inference](ht
|
|||||||
| **Quantized weights (AWQ)** | 8 x H100/800/20 |
|
| **Quantized weights (AWQ)** | 8 x H100/800/20 |
|
||||||
| | 8 x A100/A800 |
|
| | 8 x A100/A800 |
|
||||||
| **Quantized weights (int8)** | 16 x A100/800 |
|
| **Quantized weights (int8)** | 16 x A100/800 |
|
||||||
|
| | 32 x L40S |
|
||||||
|
|
||||||
<style>
|
<style>
|
||||||
.md-typeset__table {
|
.md-typeset__table {
|
||||||
@@ -56,6 +57,7 @@ Detailed commands for reference:
|
|||||||
- [4 x 8 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-four-a1008-nodes)
|
- [4 x 8 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-four-a1008-nodes)
|
||||||
- [8 x A100 (AWQ)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-8-a100a800-with-awq-quantization)
|
- [8 x A100 (AWQ)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-8-a100a800-with-awq-quantization)
|
||||||
- [16 x A100 (int8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization)
|
- [16 x A100 (int8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization)
|
||||||
|
- [32 x L40S (int8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-32-l40s-with-int8-quantization)
|
||||||
|
|
||||||
### Download Weights
|
### Download Weights
|
||||||
|
|
||||||
|
|||||||
@@ -341,6 +341,15 @@ def extend_attention_fwd(
|
|||||||
else:
|
else:
|
||||||
BLOCK_M, BLOCK_N = (32, 64)
|
BLOCK_M, BLOCK_N = (32, 64)
|
||||||
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||||
|
# 8.9 has a much smaller shared memory size (100K) than 8.0 (160K)
|
||||||
|
if CUDA_CAPABILITY[1] == 9:
|
||||||
|
if Lq <= 128:
|
||||||
|
BLOCK_M, BLOCK_N = (64, 128)
|
||||||
|
elif Lq <= 256:
|
||||||
|
BLOCK_M, BLOCK_N = (64, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (32, 32)
|
||||||
|
else:
|
||||||
if Lq <= 128:
|
if Lq <= 128:
|
||||||
BLOCK_M, BLOCK_N = (128, 128)
|
BLOCK_M, BLOCK_N = (128, 128)
|
||||||
elif Lq <= 256:
|
elif Lq <= 256:
|
||||||
|
|||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -279,6 +279,143 @@ void sm80_dispatch_shape(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dispatch shape for sm89 (L40S, L20, RTX 4090), according to:
|
||||||
|
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
|
||||||
|
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
|
||||||
|
void sm89_dispatch_shape(
|
||||||
|
torch::Tensor& out,
|
||||||
|
const torch::Tensor& mat_a,
|
||||||
|
const torch::Tensor& mat_b,
|
||||||
|
const torch::Tensor& scales_a,
|
||||||
|
const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
|
int m = mat_a.size(0);
|
||||||
|
int n = mat_b.size(1);
|
||||||
|
if (m <= 16) {
|
||||||
|
if (n <= 8192) {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<16, 128, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
4>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 32) {
|
||||||
|
if (n <= 8192) {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<32, 128, 128>,
|
||||||
|
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
4>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 64) {
|
||||||
|
if (n <= 8192) {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
3>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 128) {
|
||||||
|
if (n <= 8192) {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||||
|
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
3>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else if (n <= 16384) {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 256) {
|
||||||
|
if (n <= 4096) {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
3>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else if (n <= 8192) {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else if (n <= 16384) {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
3>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cutlass_int8_scaled_mm<
|
||||||
|
ElementOutput,
|
||||||
|
ArchTag,
|
||||||
|
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||||
|
InstructionShape,
|
||||||
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
typename ElementOutput,
|
typename ElementOutput,
|
||||||
typename TileShape,
|
typename TileShape,
|
||||||
@@ -566,6 +703,16 @@ torch::Tensor int8_scaled_mm(
|
|||||||
sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
|
sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
|
||||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
} else if (sm_version >= 80 && sm_version < 90) {
|
} else if (sm_version >= 80 && sm_version < 90) {
|
||||||
|
// sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||||
|
if (sm_version == 89) {
|
||||||
|
if (out_dtype == torch::kBFloat16) {
|
||||||
|
sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||||
|
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||||
|
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if (out_dtype == torch::kBFloat16) {
|
if (out_dtype == torch::kBFloat16) {
|
||||||
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
@@ -573,6 +720,7 @@ torch::Tensor int8_scaled_mm(
|
|||||||
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else if (sm_version == 90) {
|
} else if (sm_version == 90) {
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||||
// cutlass 3.x
|
// cutlass 3.x
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class TestInt8Gemm(unittest.TestCase):
|
|||||||
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
|
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
|
||||||
|
|
||||||
def test_accuracy(self):
|
def test_accuracy(self):
|
||||||
Ms = [1, 128, 512, 1024, 4096, 8192]
|
Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
|
||||||
Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
|
Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
|
||||||
Ks = [512, 1024, 4096, 8192, 16384]
|
Ks = [512, 1024, 4096, 8192, 16384]
|
||||||
bias_opts = [True, False]
|
bias_opts = [True, False]
|
||||||
|
|||||||
Reference in New Issue
Block a user