[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

This commit is contained in:
Cheng Wan
2025-08-14 21:14:53 -07:00
committed by GitHub
parent 584e1ab2d0
commit 295895120d
69 changed files with 956 additions and 1037 deletions

View File

@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
if is_cuda():
@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
per_channel_quant=False, # ModelOpt uses per-tensor quantization
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
)
@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
@property
def enable_flashinfer_cutlass_moe(self) -> bool:
from sglang.srt.layers.moe import get_moe_runner_backend
"""Access the global enable_flashinfer_cutlass_moe setting."""
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
return get_moe_runner_backend().is_flashinfer_cutlass()
def create_weights(
self,
@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
if self.enable_flashinfer_cutlass_moe:
assert (
not apply_router_weight_on_input
not moe_runner_config.apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
layer.w2_blockscale_swizzled.view(torch.int32),
layer.g2_alphas,
],
ep_size=ep_size,
ep_rank=ep_rank,
tp_size=tp_size,
tp_rank=tp_rank,
ep_size=layer.moe_ep_size,
ep_rank=layer.moe_ep_rank,
tp_size=layer.moe_tp_size,
tp_rank=layer.moe_tp_rank,
tune_max_num_tokens=next_power_of_2(x.shape[0]),
)[0]
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
params=layer.cutlass_moe_params,
apply_router_weight_on_input=apply_router_weight_on_input,
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
).to(x.dtype)
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output