[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

This commit is contained in:
Cheng Wan
2025-08-14 21:14:53 -07:00
committed by GitHub
parent 584e1ab2d0
commit 295895120d
69 changed files with 956 additions and 1037 deletions

View File

@@ -16,14 +16,13 @@
from __future__ import annotations
import importlib.util
import logging
from typing import TYPE_CHECKING, List, Optional
import torch
import triton.language as tl
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
@@ -60,6 +58,7 @@ if is_flashinfer_available():
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
OCP_MX_BLOCK_SIZE = 32
@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self,
prefix: str,
):
from sglang.srt.managers.schedule_batch import global_server_args_dict
super().__init__()
self.prefix = prefix
self.topk_indices_dtype = None
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logger,
f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
)
# TODO: these values are hardcoded for now, we need to get them from the model
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if self.use_flashinfer:
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
topk_output=topk_output,
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
moe_runner_config=moe_runner_config,
)
else:
return self.triton_kernel_moe_forward(
@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
else:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)