Add unit test for flashinfer fp4 moe (#8330)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
|
||||||
from sgl_kernel import scaled_fp4_quant
|
from sgl_kernel import scaled_fp4_quant
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
@@ -111,15 +114,16 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
|
|||||||
).sum(dim=1)
|
).sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
def check_moe(
|
||||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
m: int,
|
||||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
n: int,
|
||||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
k: int,
|
||||||
@torch.inference_mode()
|
e: int,
|
||||||
def test_cutlass_fp4_moe_no_graph(
|
topk: int,
|
||||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
dtype: torch.dtype,
|
||||||
|
moe_impl: Callable,
|
||||||
|
flip_w13: bool,
|
||||||
):
|
):
|
||||||
|
|
||||||
torch.manual_seed(7)
|
torch.manual_seed(7)
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||||
@@ -167,38 +171,18 @@ def test_cutlass_fp4_moe_no_graph(
|
|||||||
|
|
||||||
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||||
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||||
# strides for the cutlass moe_fp4 kernel
|
test_output = moe_impl(
|
||||||
ab_strides_13 = torch.full(
|
|
||||||
(e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device
|
|
||||||
)
|
|
||||||
c_strides_13 = torch.full(
|
|
||||||
(e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device
|
|
||||||
)
|
|
||||||
ab_strides_2 = torch.full(
|
|
||||||
(e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
|
|
||||||
)
|
|
||||||
c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
|
|
||||||
params = CutlassMoEParams(
|
|
||||||
CutlassMoEType.BlockscaledFP4,
|
|
||||||
device=a.device,
|
|
||||||
num_experts=e,
|
|
||||||
intermediate_size_per_partition=n, # n
|
|
||||||
hidden_size=k,
|
|
||||||
) # k
|
|
||||||
cutlass_output = cutlass_moe_fp4(
|
|
||||||
a=a,
|
a=a,
|
||||||
a1_gscale=a1_gs,
|
|
||||||
w1_fp4=w1_q,
|
|
||||||
w1_blockscale=w1_blockscale,
|
|
||||||
w1_alphas=(1 / w1_gs),
|
|
||||||
a2_gscale=a2_gs,
|
|
||||||
w2_fp4=w2_q,
|
|
||||||
w2_blockscale=w2_blockscale,
|
|
||||||
w2_alphas=(1 / w2_gs),
|
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
params=params,
|
w1_q=w1_q,
|
||||||
apply_router_weight_on_input=False,
|
w2_q=w2_q,
|
||||||
|
a1_gs=a1_gs,
|
||||||
|
w1_blockscale=w1_blockscale,
|
||||||
|
w1_alphas=(1 / w1_gs),
|
||||||
|
a2_gs=a2_gs,
|
||||||
|
w2_blockscale=w2_blockscale,
|
||||||
|
w2_alphas=(1 / w2_gs),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reference check:
|
# Reference check:
|
||||||
@@ -237,10 +221,108 @@ def test_cutlass_fp4_moe_no_graph(
|
|||||||
block_size=quant_blocksize,
|
block_size=quant_blocksize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if flip_w13:
|
||||||
|
dim = -2
|
||||||
|
size = w1_d.size(dim)
|
||||||
|
assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
|
||||||
|
half = size // 2
|
||||||
|
# Reorder weight
|
||||||
|
w1, w3 = w1_d.split(half, dim=dim)
|
||||||
|
w1_d = torch.cat([w3, w1], dim=dim).contiguous()
|
||||||
|
|
||||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
|
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
|
||||||
|
|
||||||
torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
|
torch.testing.assert_close(torch_output, test_output, atol=1e-1, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||||
|
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||||
|
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_cutlass_fp4_moe_no_graph(
|
||||||
|
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||||
|
):
|
||||||
|
def cutlass_moe_impl(
|
||||||
|
a,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
w1_q,
|
||||||
|
w2_q,
|
||||||
|
a1_gs,
|
||||||
|
w1_blockscale,
|
||||||
|
w1_alphas,
|
||||||
|
a2_gs,
|
||||||
|
w2_blockscale,
|
||||||
|
w2_alphas,
|
||||||
|
):
|
||||||
|
params = CutlassMoEParams(
|
||||||
|
CutlassMoEType.BlockscaledFP4,
|
||||||
|
device=a.device,
|
||||||
|
num_experts=e,
|
||||||
|
intermediate_size_per_partition=n, # n
|
||||||
|
hidden_size=k,
|
||||||
|
) # k
|
||||||
|
return cutlass_moe_fp4(
|
||||||
|
a=a,
|
||||||
|
a1_gscale=a1_gs,
|
||||||
|
w1_fp4=w1_q,
|
||||||
|
w1_blockscale=w1_blockscale,
|
||||||
|
w1_alphas=w1_alphas,
|
||||||
|
a2_gscale=a2_gs,
|
||||||
|
w2_fp4=w2_q,
|
||||||
|
w2_blockscale=w2_blockscale,
|
||||||
|
w2_alphas=w2_alphas,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
params=params,
|
||||||
|
apply_router_weight_on_input=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
check_moe(m, n, k, e, topk, dtype, cutlass_moe_impl, flip_w13=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||||
|
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||||
|
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_flashinfer_fp4_moe_no_graph(
|
||||||
|
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||||
|
):
|
||||||
|
def flashinfer_moe_impl(
|
||||||
|
a,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
w1_q,
|
||||||
|
w2_q,
|
||||||
|
a1_gs,
|
||||||
|
w1_blockscale,
|
||||||
|
w1_alphas,
|
||||||
|
a2_gs,
|
||||||
|
w2_blockscale,
|
||||||
|
w2_alphas,
|
||||||
|
):
|
||||||
|
return flashinfer_cutlass_fused_moe(
|
||||||
|
a,
|
||||||
|
topk_ids.to(torch.int),
|
||||||
|
topk_weights,
|
||||||
|
w1_q.view(torch.long),
|
||||||
|
w2_q.view(torch.long),
|
||||||
|
a.dtype,
|
||||||
|
quant_scales=[
|
||||||
|
a1_gs,
|
||||||
|
w1_blockscale.view(torch.int32),
|
||||||
|
w1_alphas,
|
||||||
|
a2_gs,
|
||||||
|
w2_blockscale.view(torch.int32),
|
||||||
|
w2_alphas,
|
||||||
|
],
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
|
test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
|
||||||
|
test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
|
||||||
|
|||||||
Reference in New Issue
Block a user