Refactor TopK to ensure readability and extensibility (#9338)
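Summary: this diff makes quantization configuration flow explicitly instead of being read from global server args. `get_moe_impl_class` now requires the caller's `quant_config`, `TopK` gains `quant_config` and a new `output_format` override, the ad-hoc `_is_fp4_quantization_enabled` helper and the `force_topk` flag are removed, and every MoE construction site passes its `quant_config` through.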
@@ -888,7 +888,7 @@ class DeepEPMoE(EPMoE):
         raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")


-def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE

@@ -901,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
         return FusedMoE
     try:
-        # Check the quantization argument directly
-        quantization = global_server_args_dict.get("quantization")
-        if quantization == "modelopt_fp4":
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
            from sglang.srt.layers.moe.fused_moe_triton.layer import (
                FlashInferFP4MoE,
            )
@@ -911,7 +910,8 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     except:
         pass

-    if should_use_flashinfer_trtllm_moe():
+    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
+        # FIXME: FlashInferFusedMoE only supports fp8 quant now
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
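Taken together, these hunks change `get_moe_impl_class` from consulting `global_server_args_dict` to inspecting the `quant_config` object that callers must now pass, and the TRT-LLM path is additionally gated on `quant_config is not None`. A condensed, standalone sketch of the resulting selection order (the stub class and keyword flags below are placeholders for the real sglang backends, and the flashinfer-cutlass branch is omitted):

```python
# Sketch of the post-refactor dispatch; not sglang's actual API surface.
from typing import Optional


class StubQuantConfig:
    def __init__(self, name: str):
        self._name = name

    def get_name(self) -> str:
        return self._name


def select_moe_impl(
    quant_config: Optional[StubQuantConfig],
    *,
    use_deepep: bool = False,
    use_trtllm: bool = False,
) -> str:
    if use_deepep:
        return "DeepEPMoE"
    if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
        return "FlashInferFP4MoE"
    if use_trtllm and quant_config is not None:
        # Mirrors the new FIXME: FlashInferFusedMoE only supports fp8 for now.
        return "FlashInferFusedMoE"
    return "FusedMoE"


assert select_moe_impl(None) == "FusedMoE"
assert select_moe_impl(StubQuantConfig("modelopt_fp4")) == "FlashInferFP4MoE"
assert select_moe_impl(StubQuantConfig("fp8"), use_trtllm=True) == "FlashInferFusedMoE"
```

Note the behavioral tightening: an unquantized layer (`quant_config=None`) can no longer land on the TRT-LLM implementation even when that backend is enabled.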
@@ -74,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)


-def _is_fp4_quantization_enabled():
-    """Check if ModelOpt FP4 quantization is enabled."""
-    try:
-        # Use the same simple check that works for class selection
-        quantization = global_server_args_dict.get("quantization")
-        return quantization == "modelopt_fp4"
-    except:
-        return False
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
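The deleted helper read the quantization name out of a process-global dict and swallowed every exception; after this refactor the equivalent information travels with the `quant_config` that is already in scope at each call site. Roughly (a hedged one-liner with a stand-in config object, not the library's helper):

```python
# Before (removed): global lookup with a bare except
#     quantization = global_server_args_dict.get("quantization")
#     return quantization == "modelopt_fp4"
# After: explicit argument, no hidden global state
def is_modelopt_fp4(quant_config) -> bool:
    # quant_config is whatever QuantizationConfig the caller holds;
    # None means the layer is unquantized.
    return quant_config is not None and quant_config.get_name() == "modelopt_fp4"
```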
@@ -19,6 +19,7 @@ import math
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import (
+    TYPE_CHECKING,
     Callable,
     NamedTuple,
     Optional,
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
     is_npu,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization import QuantizationConfig
+
 try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:
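The `TYPE_CHECKING` guard lets the TopK module annotate `quant_config: Optional[QuantizationConfig]` without importing the quantization package at runtime, which avoids creating an import cycle between the MoE and quantization layers. The pattern in isolation (a generic sketch, not sglang-specific behavior):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so it
    # cannot introduce a circular import between the two modules.
    from sglang.srt.layers.quantization import QuantizationConfig


def configure(quant_config: Optional["QuantizationConfig"] = None) -> None:
    # The quoted annotation keeps the name usable without a runtime import.
    pass


configure(None)  # runs fine even though QuantizationConfig was never imported
```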
@@ -94,6 +98,7 @@ class TopKConfig:
     torch_native: bool = False
     routed_scaling_factor: Optional[float] = None
     apply_routed_scaling_factor_on_output: bool = False
+    output_format: Optional[TopKOutputFormat] = None


 # -------------------------------- TopKOutput ---------------------------------------
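The new `output_format` field refers to a `TopKOutputFormat` enum whose definition is not part of this excerpt. Based on the three members used elsewhere in the patch, a minimal sketch consistent with those uses (member names are from the diff; the comments are my reading):

```python
from enum import Enum, auto


class TopKOutputFormat(Enum):
    # Sketch only; the real enum in sglang/srt/layers/moe/topk.py may
    # carry additional members or helper methods.
    STANDARD = auto()       # plain (topk_weights, topk_ids) tensors
    TRITON_KERNEL = auto()  # routing data for the triton_kernels backend
    BYPASSED = auto()       # top-k deferred to the fused MoE kernel itself
```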
@@ -196,9 +201,10 @@ class TopK(CustomOp):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
-        force_topk: bool = False,
+        output_format: Optional[TopKOutputFormat] = None,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -207,6 +213,14 @@ class TopK(CustomOp):
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None

+        if (
+            quant_config is not None
+            and quant_config.get_name() == "modelopt_fp4"
+            and should_use_flashinfer_trtllm_moe()
+        ):
+            # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
+            correction_bias = correction_bias.to(torch.bfloat16)
+
         self.topk_config = TopKConfig(
             top_k=top_k,
             use_grouped_topk=use_grouped_topk,
@@ -218,11 +232,9 @@ class TopK(CustomOp):
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            output_format=output_format,
         )

-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        self.force_topk = force_topk
-
     def forward_native(
         self,
         hidden_states: torch.Tensor,
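The bfloat16 cast that previously lived in `DeepseekV2MoE.__init__` now happens once inside `TopK.__init__`, keyed off the same two conditions. A toy reproduction of just that cast (requires torch; the backend predicate is reduced to a boolean argument):

```python
from typing import Optional

import torch


def maybe_cast_correction_bias(
    correction_bias: torch.Tensor,
    quant_name: Optional[str],
    trtllm_moe_enabled: bool,
) -> torch.Tensor:
    # Mirrors the new guard in TopK.__init__: the FlashInfer TRT-LLM FP4
    # path expects a bfloat16 bias (see the PR #9834 discussion linked above).
    if quant_name == "modelopt_fp4" and trtllm_moe_enabled:
        return correction_bias.to(torch.bfloat16)
    return correction_bias


bias = torch.zeros(8, dtype=torch.float32)
assert maybe_cast_correction_bias(bias, "modelopt_fp4", True).dtype == torch.bfloat16
assert maybe_cast_correction_bias(bias, None, True).dtype == torch.float32
```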
@@ -248,7 +260,19 @@
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        if self.use_triton_kernels:
+        if self.topk_config.output_format is not None:
+            output_format = self.topk_config.output_format
+        elif get_moe_runner_backend().is_triton_kernel():
+            output_format = TopKOutputFormat.TRITON_KERNEL
+        elif (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            output_format = TopKOutputFormat.BYPASSED
+        else:
+            output_format = TopKOutputFormat.STANDARD
+
+        if output_format == TopKOutputFormat.TRITON_KERNEL:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
                 router_logits,
@@ -256,10 +280,7 @@
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif not self.force_topk and (
-            should_use_flashinfer_trtllm_moe()
-            or get_moe_runner_backend().is_flashinfer_mxfp4()
-        ):
+        elif output_format == TopKOutputFormat.BYPASSED:
            return BypassedTopKOutput(
                hidden_states=hidden_states,
                router_logits=router_logits,
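The forward path now resolves the output format in one place instead of scattering backend checks across branches: an explicit `topk_config.output_format` override wins, then the triton-kernel backend, then the bypassing backends (TRT-LLM or MXFP4), else standard. As a standalone sketch with the backend probes stubbed out as booleans:

```python
from enum import Enum, auto
from typing import Optional


class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()
    BYPASSED = auto()


def resolve_output_format(
    override: Optional[TopKOutputFormat],
    *,
    triton_kernel: bool,
    trtllm_or_mxfp4: bool,
) -> TopKOutputFormat:
    # Same precedence as the new dispatch in TopK.forward_*: config
    # override first, then backend probes, then the default.
    if override is not None:
        return override
    if triton_kernel:
        return TopKOutputFormat.TRITON_KERNEL
    if trtllm_or_mxfp4:
        return TopKOutputFormat.BYPASSED
    return TopKOutputFormat.STANDARD


# An explicit override beats every backend probe.
assert resolve_output_format(
    TopKOutputFormat.STANDARD, triton_kernel=True, trtllm_or_mxfp4=True
) is TopKOutputFormat.STANDARD
```

This is what replaces the removed `force_topk` flag: rather than forcing one special case, callers can pin any format.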
@@ -105,7 +105,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_multimodal",
     "enable_symm_mem",
-    "quantization",
     "enable_custom_logit_processor",
     "disaggregation_mode",
 ]
@@ -246,7 +246,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
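Every construction site below follows the same mechanical change: thread the `quant_config` the block already receives into `get_moe_impl_class`. Because the parameter no longer has a default, forgetting it fails loudly; a minimal illustration with a stub (not the real function):

```python
from typing import Optional


def get_moe_impl_class_stub(quant_config: Optional[object]):
    # No default value: every call site must state its quant_config,
    # even if that is an explicit None for an unquantized block.
    return "FusedMoE"


get_moe_impl_class_stub(None)  # ok: unquantized block
try:
    get_moe_impl_class_stub()  # type: ignore[call-arg]
except TypeError:
    print("quant_config is now required at every call site")
```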
@@ -65,14 +65,10 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FusedMoE,
-    _is_fp4_quantization_enabled,
-)
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -375,10 +371,6 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )

-        correction_bias = self.gate.e_score_correction_bias
-        # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
-        if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
-            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -386,10 +378,13 @@
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=correction_bias,
+            correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-            force_topk=quant_config is None,
+            # Some FP4 MoE backends require the output format to be bypassed, but the
+            # MTP layers are unquantized and require the standard output format. We use
+            # quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )

         self.shared_experts_is_int8 = False
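The DeepSeek change is where the new `output_format` override earns its keep: as the new comment explains, MTP (draft) layers are built without a `quant_config`, so they pin the standard format even when the main model's FP4 backend would otherwise bypass top-k. The decision in isolation (a sketch; the enum is abbreviated from the earlier example):

```python
from enum import Enum, auto
from typing import Optional


class TopKOutputFormat(Enum):
    STANDARD = auto()
    BYPASSED = auto()


def output_format_for_layer(quant_config: Optional[object]) -> Optional[TopKOutputFormat]:
    # Unquantized layers (e.g. MTP draft layers, built with quant_config=None)
    # pin the standard format; quantized layers return None so TopK.forward
    # can pick the backend-appropriate format at runtime.
    return TopKOutputFormat.STANDARD if quant_config is None else None


assert output_format_for_layer(None) is TopKOutputFormat.STANDARD
assert output_format_for_layer(object()) is None
```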
@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )

         self.top_k = config.num_experts_per_tok
-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -143,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts,
@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
@@ -30,7 +30,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -935,7 +935,7 @@ class Qwen3NextForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
             use_grouped_topk=False,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_top_k,
             hidden_size=config.hidden_size,