[1/N] MoE Refactor: refactor select_experts (#7966)

Author: Cheng Wan
Date: 2025-07-19 00:51:15 -07:00
Committed by: GitHub
Parent: cfab0ff6e2
Commit: 15ad6c9086

39 changed files with 556 additions and 871 deletions

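The excerpt below shows one of the 39 touched files: the ModelOpt FP8 and NvFP4 fused-MoE methods. The pattern is the same across the whole refactor: `apply()` no longer takes raw `router_logits` plus nine routing knobs and calls `select_experts` itself; it now receives a single precomputed `TopKOutput` and hands it to the backend kernel. A minimal sketch of what that object must look like, inferred from the `topk_weights, topk_ids, _ = topk_output` unpacking further down (the NamedTuple shape and field names are assumptions, not definitions taken from this diff):

    from typing import NamedTuple

    import torch

    class TopKOutput(NamedTuple):
        # Hypothetical reconstruction: the diff only shows a 3-way unpack
        # in which the third field is discarded.
        topk_weights: torch.Tensor   # routing weights, one row per token
        topk_ids: torch.Tensor       # selected expert indices per token
        router_logits: torch.Tensor  # assumed third field (unpacked as `_` below)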

@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
@@ -31,6 +31,9 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_cuda, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
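`TopKOutput` is imported under `if TYPE_CHECKING:` so the new annotation can be checked without importing the MoE top-k module at runtime, avoiding a potential import cycle between the quantization and MoE layers. This works because the file already has `from __future__ import annotations` (first hunk), which leaves annotations unevaluated at runtime. A minimal self-contained illustration of the pattern, not code from this commit:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by type checkers; never executed at runtime,
        # so it cannot introduce a circular import.
        from sglang.srt.layers.moe.topk import TopKOutput

    def apply(topk_output: TopKOutput) -> None:
        # At runtime the annotation stays the unevaluated string "TopKOutput".
        ...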
@@ -402,15 +405,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
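Nine routing parameters collapse into the single `topk_output`, and the bare `*` makes every remaining parameter keyword-only, so old positional call sites fail loudly instead of silently binding `activation` or `inplace` to the wrong argument. A toy illustration of the `*` semantics (placeholder arguments, not the real signature):

    def apply(layer, x, topk_output, *, activation="silu", inplace=True):
        ...

    apply(None, None, None, activation="gelu")  # OK: keyword argument
    apply(None, None, None, "gelu")  # TypeError: takes 3 positional arguments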
@@ -418,29 +414,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
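With selection hoisted out, the FP8 path reduces to a single `fused_experts` call that consumes the `TopKOutput` whole. The implied caller-side flow, sketched under the assumption (suggested by the commit title but not shown in this excerpt) that `select_experts` now returns a `TopKOutput`; names and the argument list are placeholders, and the exact post-refactor signature is not visible here:

    def forward_moe(layer, quant_method, x, router_logits, top_k, renormalize):
        from sglang.srt.layers.moe.topk import select_experts

        # Routing is computed once, upstream of any quantization method
        # (remaining select_experts arguments elided):
        topk_output = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
        )
        # Every backend then consumes the same precomputed routing result:
        return quant_method.apply(
            layer, x, topk_output, activation="silu", inplace=True
        )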
@@ -961,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -982,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
 
-        from sglang.srt.layers.moe.topk import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         if self.enable_flashinfer_moe:
             assert (
@@ -1004,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
+            topk_weights, topk_ids, _ = topk_output
             output = flashinfer_cutlass_fused_moe(
                 x,
                 topk_ids.to(torch.int),
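Kernels that predate the refactor still take separate weight/id tensors, so the Flashinfer and Cutlass FP4 branches unpack the tuple at the call site and cast the ids to `torch.int` for the kernel. A hedged helper equivalent to those added lines (hypothetical, not in this diff):

    import torch

    def as_weight_id_pair(topk_output):
        # Adapt a TopKOutput for kernels that take separate tensors;
        # the third field is unused by these backends.
        topk_weights, topk_ids, _ = topk_output
        return topk_weights, topk_ids.to(torch.int)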
@@ -1029,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
+        topk_weights, topk_ids, _ = topk_output
         return cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,