[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

This commit is contained in:
Cheng Wan
2025-08-14 21:14:53 -07:00
committed by GitHub
parent 584e1ab2d0
commit 295895120d
69 changed files with 956 additions and 1037 deletions

View File

@@ -12,7 +12,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
run_moe_ep_preproess,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.test.test_utils import CustomTestCase
@@ -22,35 +22,26 @@ def ep_moe(
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
topk_config: TopKConfig,
# ep config
num_experts: int = 256,
fp8_dtype: torch.types = torch.float8_e4m3fn,
num_experts_per_partition: int = 128,
start_expert_id: int = 0,
end_expert_id: int = 127,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
w1_scale_inv: Optional[torch.Tensor] = None,
w2_scale_inv: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
):
use_blockwise_fp8 = block_shape is not None
topk_weights, topk_ids, _ = select_experts(
top_k = topk_config.top_k
topk_output = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
# correction_bias=correction_bias, #skip this in test
custom_routing_function=custom_routing_function,
topk_config=topk_config,
)
topk_weights, topk_ids, _ = topk_output
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
@@ -294,14 +285,18 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
start_id = cur_rank * num_experts_per_partition
end_id = start_id + num_experts_per_partition - 1
topk_config = TopKConfig(
top_k=topk,
renormalize=False,
)
with torch.inference_mode():
out = ep_moe(
hidden_states=a,
w1=w1,
w2=w2,
router_logits=score,
top_k=topk,
renormalize=False,
topk_config=topk_config,
use_fp8_w8a8=True,
w1_scale_inv=w1_s,
w2_scale_inv=w2_s,
@@ -316,8 +311,7 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
w1=w1_ref,
w2=w2_ref,
router_logits=score,
top_k=topk,
renormalize=False,
topk_config=topk_config,
use_fp8_w8a8=False,
w1_scale_inv=None,
w2_scale_inv=None,