Fix triton_fused_moe unit test and benchmark (#9276)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -8,6 +8,8 @@ from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -92,8 +94,22 @@ class TestFusedMOE(CustomTestCase):
|
||||
w2_tri = w2_tri.transpose(-2, -1).contiguous()
|
||||
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||
|
||||
topk_op = TopK(
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
use_grouped_topk=False,
|
||||
)
|
||||
topk_op.use_triton_kernels = True
|
||||
triton_topk_output = topk_op.forward_cuda(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
)
|
||||
|
||||
moe_runner_config = MoeRunnerConfig(
|
||||
inplace=False,
|
||||
)
|
||||
triton_output = triton_kernel_moe_forward(
|
||||
a, w1_tri, w2_tri, score, topk, renormalize=False
|
||||
a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
|
||||
)
|
||||
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
|
||||
torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
|
||||
|
||||
Reference in New Issue
Block a user