[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)
@@ -16,14 +16,13 @@

from __future__ import annotations

import importlib.util
import logging
from typing import TYPE_CHECKING, List, Optional

import torch
import triton.language as tl
from torch.nn.parameter import Parameter

+from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.layers.utils import is_sm100_supported
-from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,
@@ -60,6 +58,7 @@ if is_flashinfer_available():
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
    from sglang.srt.layers.moe.topk import TopKOutput

OCP_MX_BLOCK_SIZE = 32
@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        self,
        prefix: str,
    ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
        super().__init__()

        self.prefix = prefix
        self.topk_indices_dtype = None
-        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
        self.with_bias = False
-        self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
+        self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()

        self.triton_kernel_moe_forward = None
        self.triton_kernel_moe_with_bias_forward = None
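For context, the two replaced assignments above swap direct reads of global_server_args_dict for a dedicated runner-backend accessor. A minimal sketch of that accessor pattern, assuming an enum-style backend with predicate helpers; the real sglang.srt.layers.moe.utils implementation may differ, and the enum members and module-level state below are illustrative:

```python
# Sketch only: not sglang's actual module, just the shape of the pattern.
from enum import Enum


class MoeRunnerBackend(Enum):
    AUTO = "auto"
    TRITON = "triton"
    TRITON_KERNEL = "triton_kernel"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"

    def is_triton_kernel(self) -> bool:
        return self is MoeRunnerBackend.TRITON_KERNEL

    def is_flashinfer_mxfp4(self) -> bool:
        return self is MoeRunnerBackend.FLASHINFER_MXFP4


# Hypothetical module-level state resolved once from server args at startup.
_MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO


def get_moe_runner_backend() -> MoeRunnerBackend:
    # Quantization methods query this accessor instead of reading
    # global_server_args_dict["enable_triton_kernel_moe"] or
    # global_server_args_dict["enable_flashinfer_mxfp4_moe"] directly.
    return _MOE_RUNNER_BACKEND
```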
@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                logger,
                f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
            )
            # TODO: these values are hardcoded for now, we need to get them from the model
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
    ) -> torch.Tensor:
        if self.use_flashinfer:
            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
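The hunk above collapses apply()'s per-call keyword arguments into a single moe_runner_config parameter. A hedged sketch of what such a config container could look like, with fields mirroring the keyword arguments removed from the signature; the real MoeRunnerConfig lives in sglang.srt.layers.moe.moe_runner and its exact fields and defaults may differ:

```python
# Illustrative stand-in for MoeRunnerConfig; field names come from the
# keyword arguments removed in the hunk above, defaults are assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfigSketch:
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
    activation_alpha: Optional[float] = None
    swiglu_limit: Optional[float] = None
```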
@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                b1=layer.w13_weight_bias,
                b2=layer.w2_weight_bias,
                topk_output=topk_output,
-                activation=activation,
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
+                moe_runner_config=moe_runner_config,
            )
        else:
            return self.triton_kernel_moe_forward(
@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
            )
        else:
            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
                b1=layer.w13_weight_bias,
                b2=layer.w2_weight_bias,
-                inplace=inplace,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
-                routed_scaling_factor=routed_scaling_factor,
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
            )
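Taken together, the changes reduce apply() to: read the backend flags once in __init__, then thread one config object to whichever kernel path is selected. A self-contained sketch of that control flow; apart from the moe_runner_config parameter name, every identifier below is an illustrative stand-in rather than sglang's API:

```python
# Sketch of the dispatch shape in apply() after this cleanup; all names other
# than moe_runner_config are hypothetical.
from dataclasses import dataclass


@dataclass
class RunnerConfigSketch:
    activation: str = "silu"
    inplace: bool = True


def apply_sketch(
    x,
    *,
    use_flashinfer: bool,
    use_triton_kernels: bool,
    moe_runner_config: RunnerConfigSketch,
) -> str:
    if use_flashinfer:
        # FlashInfer MXFP4 path: quantize x and call the FlashInfer kernel.
        return "flashinfer_mxfp4"
    if use_triton_kernels:
        # Triton-kernel path: the config object rides along unchanged.
        return f"triton_kernel:{moe_runner_config.activation}"
    # Default path mirrors the fused_experts fallback in the last hunk.
    return f"fused_experts:inplace={moe_runner_config.inplace}"


print(apply_sketch(None, use_flashinfer=False, use_triton_kernels=False,
                   moe_runner_config=RunnerConfigSketch()))
```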