[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)
This commit is contained in:
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import is_cuda, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
if is_cuda():
|
||||
@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False, # ModelOpt uses per-tensor quantization
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
|
||||
@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@property
|
||||
def enable_flashinfer_cutlass_moe(self) -> bool:
|
||||
from sglang.srt.layers.moe import get_moe_runner_backend
|
||||
|
||||
"""Access the global enable_flashinfer_cutlass_moe setting."""
|
||||
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
|
||||
return get_moe_runner_backend().is_flashinfer_cutlass()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
ep_rank: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
|
||||
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
|
||||
@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
assert (
|
||||
not apply_router_weight_on_input
|
||||
not moe_runner_config.apply_router_weight_on_input
|
||||
), "apply_router_weight_on_input is not supported for Flashinfer"
|
||||
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_blockscale_swizzled.view(torch.int32),
|
||||
layer.g2_alphas,
|
||||
],
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
ep_size=layer.moe_ep_size,
|
||||
ep_rank=layer.moe_ep_rank,
|
||||
tp_size=layer.moe_tp_size,
|
||||
tp_rank=layer.moe_tp_rank,
|
||||
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
||||
)[0]
|
||||
if routed_scaling_factor is not None:
|
||||
output *= routed_scaling_factor
|
||||
if moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
params=layer.cutlass_moe_params,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
|
||||
).to(x.dtype)
|
||||
if routed_scaling_factor is not None:
|
||||
output *= routed_scaling_factor
|
||||
if moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user