Fix triton_fused_moe unit test and benchmark (#9276)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
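
The unit test and benchmark still called triton_kernel_moe_forward and fused_moe with raw router logits plus (topk, renormalize) arguments, which no longer match the current signatures. Routing is now precomputed up front: the triton path builds a TopK op with use_triton_kernels=True, the SGLang path uses select_experts, and the resulting topk output is passed to the kernel together with a MoeRunnerConfig.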
@@ -17,6 +17,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
     triton_kernel_moe_forward,
 )
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
 
 
 def get_model_config(model_name: str, tp_size: int):
@@ -80,13 +82,26 @@ def fused_moe_triton_api(
     input_gating,
     topk,
 ):
+    topk_op = TopK(
+        top_k=topk,
+        renormalize=False,
+        use_grouped_topk=False,
+    )
+    topk_op.use_triton_kernels = True
+    triton_topk_output = topk_op.forward_cuda(
+        hidden_states=x,
+        router_logits=input_gating,
+    )
+
+    moe_runner_config = MoeRunnerConfig(
+        inplace=False,
+    )
     return triton_kernel_moe_forward(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=False,
+        triton_topk_output,
+        moe_runner_config,
     )
 
 
@@ -103,14 +118,16 @@ def fused_moe_sglang_api(
     a2_scale=None,
     block_shape=None,
 ):
+    topk_output = select_experts(
+        hidden_states=x,
+        router_logits=input_gating,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
+    )
     return fused_moe_sglang(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=False,
-        inplace=True,
+        topk_output,
         use_fp8_w8a8=use_fp8_w8a8,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
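For reference, a minimal end-to-end sketch of the SGLang-path calling convention the benchmark now exercises. It assumes a CUDA device and that fused_moe is importable under that name from the module the benchmark aliases as fused_moe_sglang; sizes and dtypes are illustrative, not taken from this commit:

import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

# Illustrative sizes: m tokens, hidden size k, intermediate size n, e experts.
m, k, n, e, topk = 16, 128, 256, 8, 2
x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(e, 2 * n, k, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(e, k, n, dtype=torch.bfloat16, device="cuda")
input_gating = torch.randn(m, e, dtype=torch.float32, device="cuda")

# Top-k routing is computed once, outside the fused kernel.
topk_output = select_experts(
    hidden_states=x,
    router_logits=input_gating,
    topk_config=TopKConfig(top_k=topk, renormalize=False),
)
# The topk output replaces the old (input_gating, topk, renormalize) arguments.
out = fused_moe(x, w1, w2, topk_output)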
@@ -8,6 +8,8 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
     triton_kernel_moe_forward,
 )
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.test.test_utils import CustomTestCase
 
 
@@ -92,8 +94,22 @@ class TestFusedMOE(CustomTestCase):
         w2_tri = w2_tri.transpose(-2, -1).contiguous()
         score = self.create_random_cuda_tensor((m, e), dtype)
 
+        topk_op = TopK(
+            top_k=topk,
+            renormalize=False,
+            use_grouped_topk=False,
+        )
+        topk_op.use_triton_kernels = True
+        triton_topk_output = topk_op.forward_cuda(
+            hidden_states=a,
+            router_logits=score,
+        )
+
+        moe_runner_config = MoeRunnerConfig(
+            inplace=False,
+        )
         triton_output = triton_kernel_moe_forward(
-            a, w1_tri, w2_tri, score, topk, renormalize=False
+            a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
         )
         torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
         torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
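And a matching sketch of the triton-kernels path the unit test covers, assembled from the fragments above; shapes and dtypes are again illustrative assumptions:

import torch

from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK

# Illustrative sizes: m tokens, hidden size k, intermediate size n, e experts.
m, k, n, e, topk = 16, 128, 256, 8, 2
a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(e, 2 * n, k, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(e, k, n, dtype=torch.bfloat16, device="cuda")
score = torch.randn(m, e, dtype=torch.float32, device="cuda")

# The triton kernels expect the weights transposed and contiguous,
# mirroring the w1_tri/w2_tri preparation in the test.
w1_tri = w1.transpose(-2, -1).contiguous()
w2_tri = w2.transpose(-2, -1).contiguous()

# A TopK op with use_triton_kernels=True yields the topk output format
# that triton_kernel_moe_forward consumes.
topk_op = TopK(top_k=topk, renormalize=False, use_grouped_topk=False)
topk_op.use_triton_kernels = True
triton_topk_output = topk_op.forward_cuda(hidden_states=a, router_logits=score)

moe_runner_config = MoeRunnerConfig(inplace=False)
out = triton_kernel_moe_forward(
    a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
)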