Refactor TopK to ensure readability and extensibility (#9338)
This commit is contained in:
@@ -888,7 +888,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
|
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
|
||||||
|
|
||||||
|
|
||||||
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
|
||||||
if get_moe_a2a_backend().is_deepep():
|
if get_moe_a2a_backend().is_deepep():
|
||||||
return DeepEPMoE
|
return DeepEPMoE
|
||||||
|
|
||||||
@@ -901,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
|||||||
return FusedMoE
|
return FusedMoE
|
||||||
try:
|
try:
|
||||||
# Check the quantization argument directly
|
# Check the quantization argument directly
|
||||||
quantization = global_server_args_dict.get("quantization")
|
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
|
||||||
if quantization == "modelopt_fp4":
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||||
FlashInferFP4MoE,
|
FlashInferFP4MoE,
|
||||||
)
|
)
|
||||||
@@ -911,7 +910,8 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if should_use_flashinfer_trtllm_moe():
|
if should_use_flashinfer_trtllm_moe() and quant_config is not None:
|
||||||
|
# FIXME: FlashInferFusedMoE only supports fp8 quant now
|
||||||
return FlashInferFusedMoE
|
return FlashInferFusedMoE
|
||||||
if get_moe_runner_backend().is_flashinfer_cutlass():
|
if get_moe_runner_backend().is_flashinfer_cutlass():
|
||||||
return FusedMoE
|
return FusedMoE
|
||||||
|
|||||||
@@ -74,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _is_fp4_quantization_enabled():
|
|
||||||
"""Check if ModelOpt FP4 quantization is enabled."""
|
|
||||||
try:
|
|
||||||
# Use the same simple check that works for class selection
|
|
||||||
quantization = global_server_args_dict.get("quantization")
|
|
||||||
return quantization == "modelopt_fp4"
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
|
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||||
# Guess tokens per expert assuming perfect expert distribution first.
|
# Guess tokens per expert assuming perfect expert distribution first.
|
||||||
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import math
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Callable,
|
Callable,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
|
|||||||
is_npu,
|
is_npu,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.quantization import QuantizationConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -94,6 +98,7 @@ class TopKConfig:
|
|||||||
torch_native: bool = False
|
torch_native: bool = False
|
||||||
routed_scaling_factor: Optional[float] = None
|
routed_scaling_factor: Optional[float] = None
|
||||||
apply_routed_scaling_factor_on_output: bool = False
|
apply_routed_scaling_factor_on_output: bool = False
|
||||||
|
output_format: Optional[TopKOutputFormat] = None
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------- TopKOutput ---------------------------------------
|
# -------------------------------- TopKOutput ---------------------------------------
|
||||||
@@ -196,9 +201,10 @@ class TopK(CustomOp):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||||
force_topk: bool = False,
|
output_format: Optional[TopKOutputFormat] = None,
|
||||||
):
|
):
|
||||||
# NOTE: scoring_func is not used for now, but we keep it for future use
|
# NOTE: scoring_func is not used for now, but we keep it for future use
|
||||||
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
||||||
@@ -207,6 +213,14 @@ class TopK(CustomOp):
|
|||||||
if use_grouped_topk:
|
if use_grouped_topk:
|
||||||
assert num_expert_group is not None and topk_group is not None
|
assert num_expert_group is not None and topk_group is not None
|
||||||
|
|
||||||
|
if (
|
||||||
|
quant_config is not None
|
||||||
|
and quant_config.get_name() == "modelopt_fp4"
|
||||||
|
and should_use_flashinfer_trtllm_moe()
|
||||||
|
):
|
||||||
|
# https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
|
||||||
|
correction_bias = correction_bias.to(torch.bfloat16)
|
||||||
|
|
||||||
self.topk_config = TopKConfig(
|
self.topk_config = TopKConfig(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
use_grouped_topk=use_grouped_topk,
|
use_grouped_topk=use_grouped_topk,
|
||||||
@@ -218,11 +232,9 @@ class TopK(CustomOp):
|
|||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||||
|
output_format=output_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
|
||||||
self.force_topk = force_topk
|
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -248,7 +260,19 @@ class TopK(CustomOp):
|
|||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
) -> TopKOutput:
|
) -> TopKOutput:
|
||||||
if self.use_triton_kernels:
|
if self.topk_config.output_format is not None:
|
||||||
|
output_format = self.topk_config.output_format
|
||||||
|
elif get_moe_runner_backend().is_triton_kernel():
|
||||||
|
output_format = TopKOutputFormat.TRITON_KERNEL
|
||||||
|
elif (
|
||||||
|
should_use_flashinfer_trtllm_moe()
|
||||||
|
or get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||||
|
):
|
||||||
|
output_format = TopKOutputFormat.BYPASSED
|
||||||
|
else:
|
||||||
|
output_format = TopKOutputFormat.STANDARD
|
||||||
|
|
||||||
|
if output_format == TopKOutputFormat.TRITON_KERNEL:
|
||||||
# renormalize=True is equivalent to sm_first=False
|
# renormalize=True is equivalent to sm_first=False
|
||||||
routing_data, gather_idx, scatter_idx = routing(
|
routing_data, gather_idx, scatter_idx = routing(
|
||||||
router_logits,
|
router_logits,
|
||||||
@@ -256,10 +280,7 @@ class TopK(CustomOp):
|
|||||||
sm_first=not self.topk_config.renormalize,
|
sm_first=not self.topk_config.renormalize,
|
||||||
)
|
)
|
||||||
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
||||||
elif not self.force_topk and (
|
elif output_format == TopKOutputFormat.BYPASSED:
|
||||||
should_use_flashinfer_trtllm_moe()
|
|
||||||
or get_moe_runner_backend().is_flashinfer_mxfp4()
|
|
||||||
):
|
|
||||||
return BypassedTopKOutput(
|
return BypassedTopKOutput(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
|
|||||||
@@ -105,7 +105,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"weight_loader_disable_mmap",
|
"weight_loader_disable_mmap",
|
||||||
"enable_multimodal",
|
"enable_multimodal",
|
||||||
"enable_symm_mem",
|
"enable_symm_mem",
|
||||||
"quantization",
|
|
||||||
"enable_custom_logit_processor",
|
"enable_custom_logit_processor",
|
||||||
"disaggregation_mode",
|
"disaggregation_mode",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ class BailingMoESparseMoeBlock(nn.Module):
|
|||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
num_experts=self.num_experts,
|
num_experts=self.num_experts,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
layer_id=self.layer_id,
|
layer_id=self.layer_id,
|
||||||
|
|||||||
@@ -65,14 +65,10 @@ from sglang.srt.layers.moe import (
|
|||||||
get_deepep_mode,
|
get_deepep_mode,
|
||||||
get_moe_a2a_backend,
|
get_moe_a2a_backend,
|
||||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||||
should_use_flashinfer_trtllm_moe,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
FusedMoE,
|
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
|
||||||
_is_fp4_quantization_enabled,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
@@ -375,10 +371,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
correction_bias = self.gate.e_score_correction_bias
|
|
||||||
# https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
|
|
||||||
if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
|
|
||||||
correction_bias = correction_bias.to(torch.bfloat16)
|
|
||||||
self.topk = TopK(
|
self.topk = TopK(
|
||||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||||
renormalize=config.norm_topk_prob,
|
renormalize=config.norm_topk_prob,
|
||||||
@@ -386,10 +378,13 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_expert_group=config.n_group,
|
num_expert_group=config.n_group,
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
topk_group=config.topk_group,
|
topk_group=config.topk_group,
|
||||||
correction_bias=correction_bias,
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
|
quant_config=quant_config,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
||||||
force_topk=quant_config is None,
|
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
|
||||||
|
# and requires the output format to be standard. We use quant_config to determine the output format.
|
||||||
|
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.shared_experts_is_int8 = False
|
self.shared_experts_is_int8 = False
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
|
|||||||
correction_bias=self.gate.e_score_correction_bias,
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
num_experts=config.moe_num_experts,
|
num_experts=config.moe_num_experts,
|
||||||
top_k=config.moe_k,
|
top_k=config.moe_k,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
|
|||||||
@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
num_experts=config.n_routed_experts
|
num_experts=config.n_routed_experts
|
||||||
+ self.num_fused_shared_experts
|
+ self.num_fused_shared_experts
|
||||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
experts_type = get_moe_impl_class()
|
experts_type = get_moe_impl_class(quant_config)
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
if experts_type.__name__ == "FusedMoE":
|
if experts_type.__name__ == "FusedMoE":
|
||||||
quant_config_name = (
|
quant_config_name = (
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
|
|||||||
)
|
)
|
||||||
self.topk.forward = self.topk.forward_native
|
self.topk.forward = self.topk.forward_native
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
num_experts=self.num_experts,
|
num_experts=self.num_experts,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
layer_id=self.layer_id,
|
layer_id=self.layer_id,
|
||||||
@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):
|
|||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
renormalize=config.norm_topk_prob,
|
renormalize=config.norm_topk_prob,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
layer_id=self.layer_id,
|
layer_id=self.layer_id,
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
num_experts=config.num_experts,
|
num_experts=config.num_experts,
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
use_grouped_topk=False,
|
use_grouped_topk=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
num_experts=config.num_experts
|
num_experts=config.num_experts
|
||||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
@@ -935,7 +935,7 @@ class Qwen3NextForCausalLM(nn.Module):
|
|||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
|
|||||||
use_grouped_topk=False,
|
use_grouped_topk=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
num_experts=config.moe_num_experts,
|
num_experts=config.moe_num_experts,
|
||||||
top_k=config.moe_top_k,
|
top_k=config.moe_top_k,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user