Fix triton_fused_moe unit test and benchmark (#9276)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -17,6 +17,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||||
triton_kernel_moe_forward,
|
triton_kernel_moe_forward,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
|
||||||
|
|
||||||
|
|
||||||
def get_model_config(model_name: str, tp_size: int):
|
def get_model_config(model_name: str, tp_size: int):
|
||||||
@@ -80,13 +82,26 @@ def fused_moe_triton_api(
|
|||||||
input_gating,
|
input_gating,
|
||||||
topk,
|
topk,
|
||||||
):
|
):
|
||||||
|
topk_op = TopK(
|
||||||
|
top_k=topk,
|
||||||
|
renormalize=False,
|
||||||
|
use_grouped_topk=False,
|
||||||
|
)
|
||||||
|
topk_op.use_triton_kernels = True
|
||||||
|
triton_topk_output = topk_op.forward_cuda(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=input_gating,
|
||||||
|
)
|
||||||
|
|
||||||
|
moe_runner_config = MoeRunnerConfig(
|
||||||
|
inplace=False,
|
||||||
|
)
|
||||||
return triton_kernel_moe_forward(
|
return triton_kernel_moe_forward(
|
||||||
x,
|
x,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
input_gating,
|
triton_topk_output,
|
||||||
topk,
|
moe_runner_config,
|
||||||
renormalize=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -103,14 +118,16 @@ def fused_moe_sglang_api(
|
|||||||
a2_scale=None,
|
a2_scale=None,
|
||||||
block_shape=None,
|
block_shape=None,
|
||||||
):
|
):
|
||||||
|
topk_output = select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=input_gating,
|
||||||
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
|
)
|
||||||
return fused_moe_sglang(
|
return fused_moe_sglang(
|
||||||
x,
|
x,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
input_gating,
|
topk_output,
|
||||||
topk,
|
|
||||||
renormalize=False,
|
|
||||||
inplace=True,
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
w2_scale=w2_scale,
|
w2_scale=w2_scale,
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||||
triton_kernel_moe_forward,
|
triton_kernel_moe_forward,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -92,8 +94,22 @@ class TestFusedMOE(CustomTestCase):
|
|||||||
w2_tri = w2_tri.transpose(-2, -1).contiguous()
|
w2_tri = w2_tri.transpose(-2, -1).contiguous()
|
||||||
score = self.create_random_cuda_tensor((m, e), dtype)
|
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||||
|
|
||||||
|
topk_op = TopK(
|
||||||
|
top_k=topk,
|
||||||
|
renormalize=False,
|
||||||
|
use_grouped_topk=False,
|
||||||
|
)
|
||||||
|
topk_op.use_triton_kernels = True
|
||||||
|
triton_topk_output = topk_op.forward_cuda(
|
||||||
|
hidden_states=a,
|
||||||
|
router_logits=score,
|
||||||
|
)
|
||||||
|
|
||||||
|
moe_runner_config = MoeRunnerConfig(
|
||||||
|
inplace=False,
|
||||||
|
)
|
||||||
triton_output = triton_kernel_moe_forward(
|
triton_output = triton_kernel_moe_forward(
|
||||||
a, w1_tri, w2_tri, score, topk, renormalize=False
|
a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
|
||||||
)
|
)
|
||||||
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
|
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
|
||||||
torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
|
torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
|
||||||
|
|||||||
Reference in New Issue
Block a user