Add a unittest for fused_moe (#2416)
This commit is contained in:
@@ -10,7 +10,7 @@ Example usage:
|
|||||||
```bash
|
```bash
|
||||||
# Tune Qwen2-57B with FP8 and TP=4
|
# Tune Qwen2-57B with FP8 and TP=4
|
||||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||||
--model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
|
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||||
--tp-size 4 \
|
--tp-size 4 \
|
||||||
--dtype fp8_w8a8 \
|
--dtype fp8_w8a8 \
|
||||||
--tune
|
--tune
|
||||||
@@ -34,7 +34,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri
|
|||||||
|
|
||||||
# Compare with FP8 mode for Qwen2-57B
|
# Compare with FP8 mode for Qwen2-57B
|
||||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||||
--model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
|
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||||
--use-fp8
|
--use-fp8
|
||||||
|
|
||||||
# Compare with custom TP size
|
# Compare with custom TP size
|
||||||
@@ -43,3 +43,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri
|
|||||||
```
|
```
|
||||||
|
|
||||||
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
|
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
|
||||||
|
|
||||||
|
- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
|
||||||
|
|
||||||
|
Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
|
||||||
|
|||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ suites = {
|
|||||||
"test_double_sparsity.py",
|
"test_double_sparsity.py",
|
||||||
"test_embedding_openai_server.py",
|
"test_embedding_openai_server.py",
|
||||||
"test_eval_accuracy_mini.py",
|
"test_eval_accuracy_mini.py",
|
||||||
|
"test_fused_moe.py",
|
||||||
"test_get_weights_by_name.py",
|
"test_get_weights_by_name.py",
|
||||||
"test_gguf.py",
|
"test_gguf.py",
|
||||||
"test_input_embeddings.py",
|
"test_input_embeddings.py",
|
||||||
|
|||||||
126
test/srt/test_fused_moe.py
Normal file
126
test/srt/test_fused_moe.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
|
||||||
|
|
||||||
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
|
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedMOE(unittest.TestCase):
|
||||||
|
NUM_EXPERTS = [8, 64]
|
||||||
|
TOP_KS = [2, 6]
|
||||||
|
|
||||||
|
def torch_naive_moe(self, a, w1, w2, score, topk):
|
||||||
|
B, D = a.shape
|
||||||
|
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||||
|
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||||
|
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||||
|
topk_weight, topk_ids = torch.topk(score, topk)
|
||||||
|
topk_weight = topk_weight.view(-1)
|
||||||
|
topk_ids = topk_ids.view(-1)
|
||||||
|
for i in range(w1.shape[0]):
|
||||||
|
mask = topk_ids == i
|
||||||
|
if mask.sum():
|
||||||
|
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[
|
||||||
|
i
|
||||||
|
].transpose(0, 1)
|
||||||
|
return (
|
||||||
|
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||||
|
).sum(dim=1)
|
||||||
|
|
||||||
|
def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
# AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
||||||
|
capability = torch.cuda.get_device_capability()
|
||||||
|
if not (capability[0] >= 9 or capability == (8, 9)):
|
||||||
|
return
|
||||||
|
|
||||||
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||||
|
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||||
|
w1 = w1.to(torch.float8_e4m3fn)
|
||||||
|
w2 = w2.to(torch.float8_e4m3fn)
|
||||||
|
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
w1_scale = torch.randn(e, dtype=torch.float32, device="cuda")
|
||||||
|
w2_scale = torch.randn(e, dtype=torch.float32, device="cuda")
|
||||||
|
a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
|
||||||
|
a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
|
sglang_output = fused_moe(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
renormalize=False,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_output = fused_moe_vllm(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
renormalize=False,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||||
|
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||||
|
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
|
||||||
|
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
|
||||||
|
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||||
|
|
||||||
|
def test_various_configurations(self):
|
||||||
|
m_values = [1, 33, 64, 222, 1024 * 128]
|
||||||
|
n_values = [128, 1024, 2048]
|
||||||
|
k_values = [128, 511, 1024]
|
||||||
|
dtypes = [torch.float16, torch.bfloat16]
|
||||||
|
fp8_modes = [False, True]
|
||||||
|
|
||||||
|
for m in m_values:
|
||||||
|
for n in n_values:
|
||||||
|
for k in k_values:
|
||||||
|
for e in self.NUM_EXPERTS:
|
||||||
|
for topk in self.TOP_KS:
|
||||||
|
for dtype in dtypes:
|
||||||
|
for use_fp8_w8a8 in fp8_modes:
|
||||||
|
with self.subTest(
|
||||||
|
m=m,
|
||||||
|
n=n,
|
||||||
|
k=k,
|
||||||
|
e=e,
|
||||||
|
topk=topk,
|
||||||
|
dtype=dtype,
|
||||||
|
fp8=use_fp8_w8a8,
|
||||||
|
):
|
||||||
|
self._test_case(
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
e,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user