[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)
This commit is contained in:
@@ -61,7 +61,6 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
|
||||
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -300,11 +299,6 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
||||
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
|
||||
spec_algorithm=SpeculativeAlgorithm.NONE,
|
||||
speculative_num_draft_tokens=None,
|
||||
enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
|
||||
enable_deepep_moe=MoeA2ABackend(
|
||||
model_runner.server_args.moe_a2a_backend
|
||||
).is_deepep(),
|
||||
deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode),
|
||||
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
||||
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,6 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable, get_bool_env_var
|
||||
@@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC):
|
||||
)
|
||||
|
||||
if server_args.expert_distribution_recorder_mode == "stat_approx":
|
||||
if server_args.moe_a2a_backend is not None and (
|
||||
if server_args.moe_a2a_backend != "none" and (
|
||||
server_args.deepep_mode == "normal"
|
||||
):
|
||||
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if server_args.moe_a2a_backend is not None:
|
||||
if server_args.moe_a2a_backend != "none":
|
||||
if server_args.deepep_mode == "normal":
|
||||
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
||||
elif server_args.deepep_mode == "low_latency":
|
||||
|
||||
@@ -17,7 +17,7 @@ from enum import Enum, auto
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch.distributed
|
||||
import torch
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -35,6 +35,7 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_global_dp_buffer,
|
||||
get_local_dp_buffer,
|
||||
)
|
||||
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -111,7 +112,7 @@ class LayerScatterModes:
|
||||
if context.is_layer_sparse:
|
||||
return (
|
||||
ScatterMode.SCATTERED
|
||||
if not global_server_args_dict["moe_a2a_backend"].is_standard()
|
||||
if not get_moe_a2a_backend().is_none()
|
||||
else ScatterMode.FULL
|
||||
)
|
||||
else:
|
||||
|
||||
29
python/sglang/srt/layers/moe/__init__.py
Normal file
29
python/sglang/srt/layers/moe/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.utils import (
|
||||
DeepEPMode,
|
||||
MoeA2ABackend,
|
||||
MoeRunnerBackend,
|
||||
get_deepep_config,
|
||||
get_deepep_mode,
|
||||
get_moe_a2a_backend,
|
||||
get_moe_runner_backend,
|
||||
get_tbo_token_distribution_threshold,
|
||||
initialize_moe_config,
|
||||
is_tbo_enabled,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeepEPMode",
|
||||
"MoeA2ABackend",
|
||||
"MoeRunnerConfig",
|
||||
"MoeRunnerBackend",
|
||||
"initialize_moe_config",
|
||||
"get_moe_a2a_backend",
|
||||
"get_moe_runner_backend",
|
||||
"get_deepep_mode",
|
||||
"should_use_flashinfer_trtllm_moe",
|
||||
"is_tbo_enabled",
|
||||
"get_tbo_token_distribution_threshold",
|
||||
"get_deepep_config",
|
||||
]
|
||||
@@ -1,11 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.moe import (
|
||||
get_deepep_mode,
|
||||
get_moe_a2a_backend,
|
||||
get_moe_runner_backend,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
ep_gather,
|
||||
ep_scatter,
|
||||
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8 import (
|
||||
Fp8Config,
|
||||
Fp8MoEMethod,
|
||||
get_tile_tokens_dim,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
@@ -89,12 +90,11 @@ class EPMoE(FusedMoE):
|
||||
num_fused_shared_experts: int = 0,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
activation: str = "silu",
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_clamp_limit: Optional[float] = None,
|
||||
with_bias: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -106,13 +106,12 @@ class EPMoE(FusedMoE):
|
||||
top_k=top_k,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=prefix,
|
||||
activation=activation,
|
||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
with_bias=with_bias,
|
||||
)
|
||||
|
||||
@@ -163,7 +162,8 @@ class EPMoE(FusedMoE):
|
||||
)
|
||||
|
||||
assert self.quant_method is not None
|
||||
assert self.activation == "silu"
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states_device = hidden_states.device
|
||||
@@ -327,8 +327,8 @@ class EPMoE(FusedMoE):
|
||||
m_max * self.start_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
if self.routed_scaling_factor is not None:
|
||||
output *= self.routed_scaling_factor
|
||||
if self.moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= self.moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
|
||||
|
||||
@@ -349,11 +349,9 @@ class DeepEPMoE(EPMoE):
|
||||
num_fused_shared_experts: int = 0,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
activation: str = "silu",
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts,
|
||||
@@ -364,12 +362,11 @@ class DeepEPMoE(EPMoE):
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=prefix,
|
||||
activation=activation,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
self.deepep_mode = deepep_mode
|
||||
self.deepep_mode = get_deepep_mode()
|
||||
|
||||
# TODO: move to the beginning of the file
|
||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||
@@ -383,7 +380,7 @@ class DeepEPMoE(EPMoE):
|
||||
num_local_experts=self.num_local_experts,
|
||||
hidden_size=hidden_size,
|
||||
params_dtype=params_dtype,
|
||||
deepep_mode=deepep_mode,
|
||||
deepep_mode=self.deepep_mode,
|
||||
async_finish=True, # TODO
|
||||
return_recv_hook=True,
|
||||
)
|
||||
@@ -458,15 +455,19 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
|
||||
def moe_impl(self, dispatch_output: DispatchOutput):
|
||||
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
||||
|
||||
if _use_aiter:
|
||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||
return self.forward_aiter(dispatch_output)
|
||||
if _is_npu:
|
||||
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
|
||||
return self.forward_npu(dispatch_output)
|
||||
if dispatch_output.format.is_deepep_normal():
|
||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
elif dispatch_output.format.is_deepep_ll():
|
||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
return self.forward_deepgemm_masked(dispatch_output)
|
||||
else:
|
||||
@@ -490,7 +491,7 @@ class DeepEPMoE(EPMoE):
|
||||
|
||||
def forward_aiter(
|
||||
self,
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
|
||||
):
|
||||
hidden_states, topk_idx, topk_weights = (
|
||||
dispatch_output.hidden_states,
|
||||
@@ -516,7 +517,7 @@ class DeepEPMoE(EPMoE):
|
||||
quant_type=QuantType.per_128x128,
|
||||
activation=(
|
||||
ActivationType.Silu
|
||||
if self.activation == "silu"
|
||||
if self.moe_runner_config.activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
expert_mask=self.expert_mask,
|
||||
@@ -531,7 +532,7 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
|
||||
assert self.quant_method is not None
|
||||
assert self.activation == "silu"
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
if num_recv_tokens_per_expert is None:
|
||||
return hidden_states_fp8.bfloat16()
|
||||
all_tokens = sum(num_recv_tokens_per_expert)
|
||||
@@ -652,7 +653,7 @@ class DeepEPMoE(EPMoE):
|
||||
):
|
||||
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
assert self.activation == "silu"
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
|
||||
# GroupGemm-0
|
||||
num_groups, m, k = hidden_states_fp8[0].size()
|
||||
@@ -783,12 +784,12 @@ class DeepEPMoE(EPMoE):
|
||||
|
||||
|
||||
def get_moe_impl_class():
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
if get_moe_a2a_backend().is_deepep():
|
||||
return DeepEPMoE
|
||||
|
||||
# NEW: Direct FP4 detection (bypasses EP requirements)
|
||||
# Check for FP4 quantization with TRTLLM flag, regardless of EP
|
||||
if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
|
||||
if get_moe_runner_backend().is_flashinfer_trtllm():
|
||||
try:
|
||||
# Check the quantization argument directly
|
||||
quantization = global_server_args_dict.get("quantization")
|
||||
@@ -803,7 +804,7 @@ def get_moe_impl_class():
|
||||
|
||||
if should_use_flashinfer_trtllm_moe():
|
||||
return FlashInferFusedMoE
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
|
||||
if get_moe_runner_backend().is_flashinfer_cutlass():
|
||||
return FusedMoE
|
||||
if get_moe_expert_parallel_world_size() > 1:
|
||||
return EPMoE
|
||||
|
||||
@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
|
||||
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
|
||||
|
||||
def fused_moe_forward_native(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
if moe_runner_config.apply_router_weight_on_input:
|
||||
raise NotImplementedError()
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
|
||||
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||
w2_weights = layer.w2_weight[topk_ids]
|
||||
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||
if activation == "silu":
|
||||
if moe_runner_config.activation == "silu":
|
||||
x1 = F.silu(x1)
|
||||
elif activation == "gelu":
|
||||
elif moe_runner_config.activation == "gelu":
|
||||
x1 = F.gelu(x1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation=}")
|
||||
raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
|
||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||
@@ -47,16 +41,11 @@ def fused_moe_forward_native(
|
||||
def moe_forward_native(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
if moe_runner_config.apply_router_weight_on_input:
|
||||
raise NotImplementedError()
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
@@ -72,12 +61,12 @@ def moe_forward_native(
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
if activation == "silu":
|
||||
if moe_runner_config.activation == "silu":
|
||||
act = SiluAndMul()
|
||||
elif activation == "gelu":
|
||||
elif moe_runner_config.activation == "gelu":
|
||||
act = GeluAndMul()
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation=}")
|
||||
raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
|
||||
@@ -2,17 +2,20 @@
|
||||
|
||||
"""Fused MoE kernel."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8,
|
||||
scaled_fp8_quant,
|
||||
@@ -1025,8 +1028,8 @@ def inplace_fused_experts(
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
) -> None:
|
||||
fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -1053,8 +1056,8 @@ def inplace_fused_experts(
|
||||
block_shape,
|
||||
False,
|
||||
routed_scaling_factor,
|
||||
activation_alpha,
|
||||
swiglu_limit,
|
||||
gemm1_alpha,
|
||||
gemm1_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -1081,8 +1084,8 @@ def inplace_fused_experts_fake(
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -1119,8 +1122,8 @@ def outplace_fused_experts(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -1147,8 +1150,8 @@ def outplace_fused_experts(
|
||||
block_shape,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_limit=gemm1_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -1176,8 +1179,8 @@ def outplace_fused_experts_fake(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -1194,12 +1197,10 @@ def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1212,14 +1213,10 @@ def fused_experts(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
):
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
if inplace:
|
||||
assert not no_combine, "no combine + inplace makes no sense"
|
||||
if moe_runner_config.inplace:
|
||||
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
|
||||
torch.ops.sglang.inplace_fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
@@ -1228,8 +1225,8 @@ def fused_experts(
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
moe_runner_config.activation,
|
||||
moe_runner_config.apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
@@ -1242,9 +1239,9 @@ def fused_experts(
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
block_shape,
|
||||
routed_scaling_factor,
|
||||
activation_alpha,
|
||||
swiglu_limit,
|
||||
moe_runner_config.routed_scaling_factor,
|
||||
moe_runner_config.gemm1_alpha,
|
||||
moe_runner_config.gemm1_clamp_limit,
|
||||
)
|
||||
return hidden_states
|
||||
else:
|
||||
@@ -1256,8 +1253,8 @@ def fused_experts(
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
moe_runner_config.activation,
|
||||
moe_runner_config.apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
@@ -1270,10 +1267,10 @@ def fused_experts(
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
block_shape,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
no_combine=moe_runner_config.no_combine,
|
||||
routed_scaling_factor=moe_runner_config.routed_scaling_factor,
|
||||
gemm1_alpha=moe_runner_config.gemm1_alpha,
|
||||
gemm1_limit=moe_runner_config.gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -1370,11 +1367,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
|
||||
|
||||
|
||||
@torch.compile
|
||||
def swiglu_with_alpha_and_limit(x, alpha, limit):
|
||||
def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
|
||||
gate, up = x[..., ::2], x[..., 1::2]
|
||||
gate = gate.clamp(min=None, max=limit)
|
||||
up = up.clamp(min=-limit, max=limit)
|
||||
return gate * torch.sigmoid(gate * alpha) * (up + 1)
|
||||
gate = gate.clamp(min=None, max=gemm1_limit)
|
||||
up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
|
||||
return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
|
||||
|
||||
|
||||
def fused_experts_impl(
|
||||
@@ -1402,8 +1399,8 @@ def fused_experts_impl(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
):
|
||||
padded_size = padding_size
|
||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||
@@ -1533,12 +1530,12 @@ def fused_experts_impl(
|
||||
block_shape=block_shape,
|
||||
)
|
||||
if activation == "silu":
|
||||
if activation_alpha is not None:
|
||||
assert swiglu_limit is not None
|
||||
if gemm1_alpha is not None:
|
||||
assert gemm1_limit is not None
|
||||
intermediate_cache2 = swiglu_with_alpha_and_limit(
|
||||
intermediate_cache1.view(-1, N),
|
||||
activation_alpha,
|
||||
swiglu_limit,
|
||||
gemm1_alpha,
|
||||
gemm1_limit,
|
||||
)
|
||||
elif _is_cuda:
|
||||
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||
@@ -1547,10 +1544,8 @@ def fused_experts_impl(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
elif activation == "gelu":
|
||||
assert (
|
||||
activation_alpha is None
|
||||
), "activation_alpha is not supported for gelu"
|
||||
assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
|
||||
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
|
||||
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
|
||||
if _is_cuda:
|
||||
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||
else:
|
||||
@@ -1641,12 +1636,10 @@ def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(),
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1659,10 +1652,6 @@ def fused_moe(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
@@ -1672,11 +1661,10 @@ def fused_moe(
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_output (TopKOutput): The top-k output of the experts.
|
||||
- topk_output (StandardTopKOutput): The top-k output of the experts.
|
||||
- moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner.
|
||||
- b1 (Optional[torch.Tensor]): Optional bias for w1.
|
||||
- b2 (Optional[torch.Tensor]): Optional bias for w2.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
|
||||
@@ -1696,9 +1684,9 @@ def fused_moe(
|
||||
a2.
|
||||
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
||||
quantization.
|
||||
- activation_alpha (Optional[float]): Optional alpha for the activation
|
||||
- gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation
|
||||
function.
|
||||
- swiglu_limit (Optional[float]): Optional limit for the swiglu activation
|
||||
- gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation
|
||||
function.
|
||||
|
||||
Returns:
|
||||
@@ -1710,11 +1698,9 @@ def fused_moe(
|
||||
w1,
|
||||
w2,
|
||||
topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
b1=b1,
|
||||
b2=b2,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
@@ -1727,8 +1713,4 @@ def fused_moe(
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||
|
||||
import datetime
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -22,8 +18,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||
use_symmetric_memory,
|
||||
)
|
||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.layers.moe import (
|
||||
MoeRunnerConfig,
|
||||
get_moe_runner_backend,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
@@ -126,7 +126,6 @@ class FusedMoE(torch.nn.Module):
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = False,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -134,9 +133,8 @@ class FusedMoE(torch.nn.Module):
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_clamp_limit: Optional[float] = None,
|
||||
use_weight_loader_fused: bool = False,
|
||||
with_bias=False,
|
||||
):
|
||||
@@ -153,9 +151,17 @@ class FusedMoE(torch.nn.Module):
|
||||
self.expert_map_cpu = None
|
||||
self.expert_map_gpu = None
|
||||
|
||||
# For activation
|
||||
self.activation_alpha = activation_alpha
|
||||
self.swiglu_limit = swiglu_limit
|
||||
self.moe_runner_config = MoeRunnerConfig(
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
|
||||
|
||||
if enable_flashinfer_cutlass_moe and quant_config is None:
|
||||
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
||||
@@ -184,20 +190,12 @@ class FusedMoE(torch.nn.Module):
|
||||
* self.num_local_experts
|
||||
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
||||
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
assert intermediate_size % self.moe_tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.activation = activation
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
self.inplace = inplace
|
||||
self.no_combine = no_combine
|
||||
|
||||
self.use_triton_kernels = (
|
||||
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
|
||||
)
|
||||
|
||||
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||
self.use_triton_kernels
|
||||
@@ -207,14 +205,12 @@ class FusedMoE(torch.nn.Module):
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
|
||||
"enable_flashinfer_mxfp4_moe", False
|
||||
)
|
||||
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||
# TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
|
||||
if (
|
||||
self.quant_config is not None
|
||||
and self.quant_config.get_name() == "mxfp4"
|
||||
and self.use_enable_flashinfer_mxfp4_moe
|
||||
and self.use_flashinfer_mxfp4_moe
|
||||
):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
self.quant_method.create_weights(
|
||||
@@ -794,7 +790,7 @@ class FusedMoE(torch.nn.Module):
|
||||
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||
origin_hidden_states_dim = hidden_states.shape[-1]
|
||||
assert self.quant_method is not None
|
||||
|
||||
@@ -803,40 +799,22 @@ class FusedMoE(torch.nn.Module):
|
||||
# If we are in EP mode, we need to move the expert map to GPU.
|
||||
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||
|
||||
if self.expert_map_gpu is not None and isinstance(
|
||||
topk_output, StandardTopKOutput
|
||||
):
|
||||
topk_output = topk_output._replace(
|
||||
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||
)
|
||||
if self.expert_map_gpu is not None:
|
||||
if TopKOutputChecker.format_is_standard(topk_output):
|
||||
topk_output = topk_output._replace(
|
||||
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||
)
|
||||
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
|
||||
raise NotImplementedError()
|
||||
|
||||
# Matrix multiply.
|
||||
with use_symmetric_memory(get_tp_group()) as sm:
|
||||
kwargs = {}
|
||||
if self.activation_alpha is not None:
|
||||
kwargs["activation_alpha"] = self.activation_alpha
|
||||
if self.swiglu_limit is not None:
|
||||
kwargs["swiglu_limit"] = self.swiglu_limit
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
topk_output=topk_output,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
**(
|
||||
dict(
|
||||
tp_rank=self.moe_tp_rank,
|
||||
tp_size=self.moe_tp_size,
|
||||
ep_rank=self.moe_ep_rank,
|
||||
ep_size=self.moe_ep_size,
|
||||
)
|
||||
if self.quant_method.__class__.__name__
|
||||
== "ModelOptNvFp4FusedMoEMethod"
|
||||
else {}
|
||||
),
|
||||
**kwargs,
|
||||
moe_runner_config=self.moe_runner_config,
|
||||
)
|
||||
sm.tag(final_hidden_states)
|
||||
|
||||
@@ -944,24 +922,10 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
class FlashInferFusedMoE(FusedMoE):
|
||||
def __init__(self, *args, **kwargs):
|
||||
renormalize = kwargs.pop("renormalize", True)
|
||||
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
|
||||
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
|
||||
num_expert_group = kwargs.pop("num_expert_group", None)
|
||||
topk_group = kwargs.pop("topk_group", None)
|
||||
correction_bias = kwargs.pop("correction_bias", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.renormalize = renormalize
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.correction_bias = correction_bias
|
||||
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||
assert self.use_flashinfer_trtllm_moe
|
||||
assert (
|
||||
self.activation == "silu"
|
||||
@@ -974,20 +938,14 @@ class FlashInferFusedMoE(FusedMoE):
|
||||
self.num_fused_shared_experts == 0
|
||||
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
|
||||
|
||||
# TRTLLM mode expects (TopK_config, router_logits) tuple
|
||||
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
|
||||
raise ValueError(
|
||||
f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
|
||||
)
|
||||
_, router_logits = topk_output
|
||||
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply_with_router_logits(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
activation=self.activation,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=self.moe_runner_config,
|
||||
)
|
||||
|
||||
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||||
@@ -1000,28 +958,8 @@ class FlashInferFP4MoE(FusedMoE):
|
||||
"""FP4 TRTLLM MoE implementation using FlashInfer."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Extract DeepSeek-specific parameters
|
||||
renormalize = kwargs.pop("renormalize", True)
|
||||
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
|
||||
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
|
||||
num_expert_group = kwargs.pop("num_expert_group", None)
|
||||
topk_group = kwargs.pop("topk_group", None)
|
||||
correction_bias = kwargs.pop("correction_bias", None)
|
||||
|
||||
# Extract additional TopK parameters that were previously extracted in forward
|
||||
routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Store DeepSeek parameters
|
||||
self.renormalize = renormalize
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.correction_bias = correction_bias
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Helper: quantize hidden states to FP4 each forward pass
|
||||
# ---------------------------------------------------------------------
|
||||
@@ -1052,21 +990,17 @@ class FlashInferFP4MoE(FusedMoE):
|
||||
|
||||
return hs_fp4, hs_sf
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output):
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||
"""Forward pass using FP4 TRTLLM kernel.
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor
|
||||
topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
|
||||
topk_output: TopKOutput object with Bypassed format
|
||||
"""
|
||||
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||
|
||||
# TRTLLM mode expects (TopK_config, router_logits) tuple
|
||||
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
|
||||
raise ValueError(
|
||||
f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
|
||||
)
|
||||
|
||||
_, router_logits = topk_output
|
||||
router_logits = topk_output.router_logits
|
||||
topk_config = topk_output.topk_config
|
||||
|
||||
hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
|
||||
|
||||
@@ -1074,7 +1008,7 @@ class FlashInferFP4MoE(FusedMoE):
|
||||
|
||||
result = trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=self.correction_bias.to(hidden_states.dtype),
|
||||
routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
|
||||
hidden_states=hs_fp4,
|
||||
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
|
||||
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
|
||||
@@ -1094,15 +1028,15 @@ class FlashInferFP4MoE(FusedMoE):
|
||||
output1_scale_gate_scalar=self.g1_alphas.data,
|
||||
output2_scale_scalar=self.g2_alphas.data,
|
||||
num_experts=self.num_experts,
|
||||
top_k=self.top_k,
|
||||
n_group=self.num_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
top_k=topk_config.top_k,
|
||||
n_group=topk_config.num_expert_group,
|
||||
topk_group=topk_config.topk_group,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
|
||||
local_num_experts=self.num_local_experts,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
|
||||
tile_tokens_dim=_get_tile_tokens_dim(
|
||||
hidden_states.shape[0], self.top_k, self.num_local_experts
|
||||
hidden_states.shape[0], topk_config.top_k, self.num_local_experts
|
||||
),
|
||||
routing_method_type=RoutingMethodType.DeepSeekV3,
|
||||
do_finalize=True,
|
||||
|
||||
@@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
||||
from triton_kernels.swiglu import swiglu_fn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
|
||||
@@ -55,8 +56,7 @@ def triton_kernel_moe_forward(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
@@ -69,7 +69,10 @@ def triton_kernel_moe_forward(
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert topk_output.format.is_triton_kernel()
|
||||
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||
|
||||
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
|
||||
|
||||
routing_data, gather_idx, scatter_idx = topk_output
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
@@ -79,8 +82,8 @@ def triton_kernel_moe_forward(
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
inplace=False, # triton kernel doesn't support inplace
|
||||
activation=moe_runner_config.activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
per_channel_quant=per_channel_quant,
|
||||
@@ -192,8 +195,7 @@ def triton_kernel_moe_with_bias_forward(
|
||||
w2_pcg,
|
||||
b2: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
@@ -203,10 +205,11 @@ def triton_kernel_moe_with_bias_forward(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
assert topk_output.format.is_triton_kernel()
|
||||
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||
|
||||
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
|
||||
|
||||
routing_data, gather_idx, scatter_idx = topk_output
|
||||
|
||||
return triton_kernel_fused_experts_with_bias(
|
||||
@@ -220,8 +223,8 @@ def triton_kernel_moe_with_bias_forward(
|
||||
routing_data=routing_data,
|
||||
gather_indx=gather_idx,
|
||||
scatter_indx=scatter_idx,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
inplace=False, # triton kernel doesn't support inplace
|
||||
activation=moe_runner_config.activation,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
@@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward(
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
gemm1_alpha=moe_runner_config.gemm1_alpha,
|
||||
gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[int] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_clamp_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
|
||||
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
||||
assert per_channel_quant == False, "per_channel_quant is not supported"
|
||||
assert expert_map == None, "expert_map is not supported"
|
||||
@@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias(
|
||||
|
||||
act = FusedActivation(
|
||||
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
|
||||
(activation_alpha, swiglu_limit),
|
||||
(gemm1_alpha, gemm1_clamp_limit),
|
||||
2,
|
||||
)
|
||||
|
||||
|
||||
3
python/sglang/srt/layers/moe/moe_runner/__init__.py
Normal file
3
python/sglang/srt/layers/moe/moe_runner/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
||||
|
||||
__all__ = ["MoeRunnerConfig"]
|
||||
13
python/sglang/srt/layers/moe/moe_runner/base.py
Normal file
13
python/sglang/srt/layers/moe/moe_runner/base.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoeRunnerConfig:
|
||||
activation: str = "silu"
|
||||
apply_router_weight_on_input: bool = False
|
||||
inplace: bool = True
|
||||
no_combine: bool = False
|
||||
routed_scaling_factor: Optional[float] = None
|
||||
gemm1_alpha: Optional[float] = None
|
||||
gemm1_clamp_limit: Optional[float] = None
|
||||
@@ -2,20 +2,26 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||
BaseDispatcher,
|
||||
BaseDispatcherConfig,
|
||||
DispatchOutput,
|
||||
DispatchOutputChecker,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPConfig,
|
||||
DeepEPDispatcher,
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
|
||||
|
||||
__all__ = [
|
||||
"AscendDeepEPLLOutput",
|
||||
"BaseDispatcher",
|
||||
"BaseDispatcherConfig",
|
||||
"DispatchOutput",
|
||||
"DispatchOutputFormat",
|
||||
"DispatchOutputChecker",
|
||||
"StandardDispatchOutput",
|
||||
"DeepEPConfig",
|
||||
"DeepEPDispatcher",
|
||||
"DeepEPNormalOutput",
|
||||
|
||||
@@ -2,35 +2,76 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, auto
|
||||
from typing import Protocol, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalOutput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
class MoEA2ABackend(Enum):
|
||||
none = "none"
|
||||
deepep = "deepep"
|
||||
|
||||
def is_none(self):
|
||||
return self == MoEA2ABackend.none
|
||||
class DispatchOutputChecker:
|
||||
|
||||
def is_deepep(self):
|
||||
return self == MoEA2ABackend.deepep
|
||||
@staticmethod
|
||||
def format_is_standard(
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> TypeGuard[StandardDispatchOutput]:
|
||||
return dispatch_output.format.is_standard()
|
||||
|
||||
@staticmethod
|
||||
def format_is_deepep_normal(
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> TypeGuard[DeepEPNormalOutput]:
|
||||
return dispatch_output.format.is_deepep_normal()
|
||||
|
||||
@staticmethod
|
||||
def format_is_deepep_ll(
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> TypeGuard[DeepEPLLOutput]:
|
||||
return dispatch_output.format.is_deepep_ll()
|
||||
|
||||
@staticmethod
|
||||
def format_is_deepep(
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
|
||||
return dispatch_output.format.is_deepep()
|
||||
|
||||
@staticmethod
|
||||
def format_is_ascent_ll(
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> TypeGuard[AscendDeepEPLLOutput]:
|
||||
return dispatch_output.format.is_ascent_ll()
|
||||
|
||||
|
||||
class DispatchOutputFormat(Enum):
|
||||
standard = auto()
|
||||
deepep_normal = auto()
|
||||
deepep_ll = auto()
|
||||
|
||||
STANDARD = auto()
|
||||
DEEPEP_NORMAL = auto()
|
||||
DEEPEP_LL = auto()
|
||||
ASCENT_LL = auto()
|
||||
|
||||
def is_standard(self) -> bool:
|
||||
return self == DispatchOutputFormat.standard
|
||||
return self == DispatchOutputFormat.STANDARD
|
||||
|
||||
def is_deepep_normal(self) -> bool:
|
||||
return self == DispatchOutputFormat.deepep_normal
|
||||
return self == DispatchOutputFormat.DEEPEP_NORMAL
|
||||
|
||||
def is_deepep_ll(self) -> bool:
|
||||
return self == DispatchOutputFormat.deepep_ll
|
||||
return self == DispatchOutputFormat.DEEPEP_LL
|
||||
|
||||
def is_deepep(self) -> bool:
|
||||
return self in [
|
||||
DispatchOutputFormat.DEEPEP_NORMAL,
|
||||
DispatchOutputFormat.DEEPEP_LL,
|
||||
]
|
||||
|
||||
def is_ascent_ll(self) -> bool:
|
||||
return self == DispatchOutputFormat.ASCENT_LL
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
||||
@@ -2,27 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
|
||||
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||
BaseDispatcher,
|
||||
BaseDispatcherConfig,
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
@@ -72,7 +62,7 @@ class DeepEPNormalOutput(NamedTuple):
|
||||
|
||||
@property
|
||||
def format(self) -> DispatchOutputFormat:
|
||||
return DispatchOutputFormat.deepep_normal
|
||||
return DispatchOutputFormat.DEEPEP_NORMAL
|
||||
|
||||
|
||||
class DeepEPLLOutput(NamedTuple):
|
||||
@@ -86,7 +76,7 @@ class DeepEPLLOutput(NamedTuple):
|
||||
|
||||
@property
|
||||
def format(self) -> DispatchOutputFormat:
|
||||
return DispatchOutputFormat.deepep_ll
|
||||
return DispatchOutputFormat.DEEPEP_LL
|
||||
|
||||
|
||||
class AscendDeepEPLLOutput(NamedTuple):
|
||||
@@ -101,7 +91,7 @@ class AscendDeepEPLLOutput(NamedTuple):
|
||||
|
||||
@property
|
||||
def format(self) -> DispatchOutputFormat:
|
||||
return DispatchOutputFormat.deepep_ll
|
||||
return DispatchOutputFormat.ASCENT_LL
|
||||
|
||||
|
||||
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
||||
@@ -128,8 +118,8 @@ class DeepEPBuffer:
|
||||
hidden_size: int,
|
||||
param_bytes: int,
|
||||
deepep_mode: DeepEPMode,
|
||||
num_max_dispatch_tokens_per_rank: int = None,
|
||||
num_experts: int = None,
|
||||
num_max_dispatch_tokens_per_rank: int = -1,
|
||||
num_experts: int = -1,
|
||||
):
|
||||
if cls._buffer is not None:
|
||||
return cls._buffer
|
||||
@@ -156,8 +146,8 @@ class DeepEPBuffer:
|
||||
num_rdma_bytes,
|
||||
)
|
||||
if deepep_mode.enable_low_latency():
|
||||
assert num_max_dispatch_tokens_per_rank is not None
|
||||
assert num_experts is not None and num_experts % group.size() == 0
|
||||
assert num_max_dispatch_tokens_per_rank != -1
|
||||
assert num_experts != -1 and num_experts % group.size() == 0
|
||||
num_rdma_bytes = max(
|
||||
Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
@@ -181,7 +171,7 @@ class DeepEPBuffer:
|
||||
).multi_processor_count
|
||||
if (
|
||||
(deepep_mode != DeepEPMode.LOW_LATENCY)
|
||||
and not global_server_args_dict["enable_two_batch_overlap"]
|
||||
and not is_tbo_enabled()
|
||||
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
|
||||
):
|
||||
logger.warning(
|
||||
@@ -226,7 +216,7 @@ class DeepEPConfig(BaseDispatcherConfig):
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
config_str = global_server_args_dict["deepep_config"]
|
||||
config_str = get_deepep_config()
|
||||
if config_str:
|
||||
config_parsed = load_json_config(config_str)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
|
||||
@@ -13,7 +13,7 @@ class StandardDispatchOutput(NamedTuple):
|
||||
|
||||
@property
|
||||
def format(self) -> DispatchOutputFormat:
|
||||
return DispatchOutputFormat.standard
|
||||
return DispatchOutputFormat.STANDARD
|
||||
|
||||
|
||||
assert isinstance(StandardDispatchOutput, DispatchOutput)
|
||||
|
||||
@@ -14,9 +14,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
|
||||
from typing import (
|
||||
Callable,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
TypeGuard,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import (
|
||||
ExpertLocationDispatchInfo,
|
||||
topk_ids_logical_to_physical,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.layers.moe import (
|
||||
get_moe_runner_backend,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
@@ -43,6 +55,7 @@ try:
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
||||
except ImportError:
|
||||
pass
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -65,13 +78,48 @@ if _use_aiter:
|
||||
if _is_npu:
|
||||
import torch_npu
|
||||
|
||||
# -------------------------------- TopKConfig ---------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class TopKConfig:
|
||||
top_k: int
|
||||
use_grouped_topk: bool = False
|
||||
topk_group: int = 0
|
||||
num_expert_group: int = 0
|
||||
renormalize: bool = True
|
||||
num_fused_shared_experts: int = 0
|
||||
custom_routing_function: Optional[Callable] = None
|
||||
correction_bias: Optional[torch.Tensor] = None
|
||||
torch_native: bool = False
|
||||
routed_scaling_factor: Optional[float] = None
|
||||
apply_routed_scaling_factor_on_output: bool = False
|
||||
|
||||
|
||||
# -------------------------------- TopKOutput ---------------------------------------
|
||||
|
||||
|
||||
class TopKOutputChecker:
|
||||
|
||||
@staticmethod
|
||||
def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
|
||||
return topk_output.format.is_standard()
|
||||
|
||||
@staticmethod
|
||||
def format_is_triton_kernel(
|
||||
topk_output: TopKOutput,
|
||||
) -> TypeGuard[TritonKernelTopKOutput]:
|
||||
return topk_output.format.is_triton_kernel()
|
||||
|
||||
@staticmethod
|
||||
def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
|
||||
return topk_output.format.is_bypassed()
|
||||
|
||||
|
||||
class TopKOutputFormat(Enum):
|
||||
STANDARD = auto()
|
||||
TRITON_KERNEL = auto()
|
||||
BYPASSED = auto()
|
||||
|
||||
def is_standard(self) -> bool:
|
||||
return self == TopKOutputFormat.STANDARD
|
||||
@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum):
|
||||
def is_triton_kernel(self) -> bool:
|
||||
return self == TopKOutputFormat.TRITON_KERNEL
|
||||
|
||||
def is_bypassed(self) -> bool:
|
||||
return self == TopKOutputFormat.BYPASSED
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TopKOutput(Protocol):
|
||||
@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple):
|
||||
return TopKOutputFormat.TRITON_KERNEL
|
||||
|
||||
|
||||
class BypassedTopKOutput(NamedTuple):
|
||||
"""Bypassed top-k output format."""
|
||||
|
||||
hidden_states: torch.Tensor
|
||||
router_logits: torch.Tensor
|
||||
topk_config: TopKConfig
|
||||
num_token_non_padded: Optional[torch.Tensor] = None
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None
|
||||
|
||||
@property
|
||||
def format(self) -> TopKOutputFormat:
|
||||
return TopKOutputFormat.BYPASSED
|
||||
|
||||
|
||||
# -------------------------------- TopK ---------------------------------------
|
||||
|
||||
|
||||
@@ -124,8 +189,8 @@ class TopK(CustomOp):
|
||||
top_k: int,
|
||||
*,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int = 0,
|
||||
num_expert_group: int = 0,
|
||||
renormalize: bool = True,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
@@ -136,19 +201,23 @@ class TopK(CustomOp):
|
||||
# NOTE: scoring_func is not used for now, but we keep it for future use
|
||||
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
||||
super().__init__()
|
||||
|
||||
if use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.top_k = top_k
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
self.renormalize = renormalize
|
||||
self.topk_group = topk_group
|
||||
self.num_expert_group = num_expert_group
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.correction_bias = correction_bias
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
|
||||
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
||||
self.topk_config = TopKConfig(
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
@@ -158,20 +227,11 @@ class TopK(CustomOp):
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
) -> TopKOutput:
|
||||
torch_native = True
|
||||
self.topk_config.torch_native = True
|
||||
return select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
torch_native=torch_native,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
topk_config=self.topk_config,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
@@ -187,24 +247,28 @@ class TopK(CustomOp):
|
||||
if self.use_triton_kernels:
|
||||
# renormalize=True is equivalent to sm_first=False
|
||||
routing_data, gather_idx, scatter_idx = routing(
|
||||
router_logits, self.top_k, sm_first=not self.renormalize
|
||||
router_logits,
|
||||
self.topk_config.top_k,
|
||||
sm_first=not self.topk_config.renormalize,
|
||||
)
|
||||
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
||||
elif (
|
||||
should_use_flashinfer_trtllm_moe()
|
||||
or get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||
):
|
||||
return BypassedTopKOutput(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
topk_config=self.topk_config,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
else:
|
||||
torch_native = False
|
||||
self.topk_config.torch_native = False
|
||||
return select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
torch_native=torch_native,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
topk_config=self.topk_config,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
@@ -220,15 +284,7 @@ class TopK(CustomOp):
|
||||
return select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
topk_config=self.topk_config,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
@@ -244,35 +300,29 @@ class TopK(CustomOp):
|
||||
global_num_experts = router_logits.shape[-1]
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
if global_num_experts == 256 and self.topk_config.renormalize is False:
|
||||
|
||||
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
return torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=self.top_k,
|
||||
bias=self.correction_bias.to(torch.float32),
|
||||
k_group=self.topk_group,
|
||||
group_count=self.num_expert_group,
|
||||
k=self.topk_config.top_k,
|
||||
bias=self.topk_config.correction_bias.to(torch.float32),
|
||||
k_group=self.topk_config.topk_group,
|
||||
group_count=self.topk_config.num_expert_group,
|
||||
group_select_mode=1,
|
||||
renorm=0,
|
||||
norm_type=1,
|
||||
routed_scaling_factor=1,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=float(1e-20),
|
||||
)
|
||||
else:
|
||||
torch_native = True
|
||||
self.topk_config.torch_native = True
|
||||
return select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
torch_native=torch_native,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
topk_config=self.topk_config,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
@@ -670,20 +720,23 @@ else:
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
topk_config: TopKConfig,
|
||||
*,
|
||||
use_grouped_topk: bool = False,
|
||||
renormalize: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
torch_native: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
) -> TopKOutput:
|
||||
) -> StandardTopKOutput:
|
||||
|
||||
top_k = topk_config.top_k
|
||||
use_grouped_topk = topk_config.use_grouped_topk
|
||||
topk_group = topk_config.topk_group
|
||||
num_expert_group = topk_config.num_expert_group
|
||||
renormalize = topk_config.renormalize
|
||||
num_fused_shared_experts = topk_config.num_fused_shared_experts
|
||||
custom_routing_function = topk_config.custom_routing_function
|
||||
correction_bias = topk_config.correction_bias
|
||||
torch_native = topk_config.torch_native
|
||||
routed_scaling_factor = topk_config.routed_scaling_factor
|
||||
|
||||
router_logits, correction_bias = (
|
||||
expert_location_dispatch.transform_select_experts_inputs(
|
||||
router_logits=router_logits,
|
||||
|
||||
@@ -1,55 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from packaging import version as pkg_version
|
||||
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import logger
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def should_use_flashinfer_trtllm_moe():
|
||||
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
|
||||
not importlib.util.find_spec("flashinfer")
|
||||
or pkg_version.parse(__import__("flashinfer").__version__)
|
||||
>= pkg_version.parse("0.2.9rc1")
|
||||
)
|
||||
return result
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
class MoeA2ABackend(Enum):
|
||||
|
||||
STANDARD = ("standard", "none")
|
||||
NONE = "none"
|
||||
DEEPEP = "deepep"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value is None:
|
||||
return cls.STANDARD
|
||||
return cls.NONE
|
||||
for member in cls:
|
||||
if value in member.value:
|
||||
if value == member.value:
|
||||
return member
|
||||
raise ValueError(f"No {cls.__name__} member for value {value}")
|
||||
|
||||
def is_none(self):
|
||||
return self == MoeA2ABackend.NONE
|
||||
|
||||
def is_deepep(self):
|
||||
return self == MoeA2ABackend.DEEPEP
|
||||
|
||||
def is_standard(self):
|
||||
return self == MoeA2ABackend.STANDARD
|
||||
|
||||
class MoeRunnerBackend(Enum):
|
||||
|
||||
AUTO = "auto"
|
||||
TRITON = "triton"
|
||||
TRITON_KERNEL = "triton_kernel"
|
||||
FLASHINFER = "flashinfer_trtllm"
|
||||
FLASHINFER_CUTLASS = "flashinfer_cutlass"
|
||||
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
|
||||
|
||||
def is_auto(self):
|
||||
return self == MoeRunnerBackend.AUTO
|
||||
|
||||
def is_triton(self):
|
||||
return self == MoeRunnerBackend.TRITON
|
||||
|
||||
def is_triton_kernel(self):
|
||||
return self == MoeRunnerBackend.TRITON_KERNEL
|
||||
|
||||
def is_flashinfer_trtllm(self):
|
||||
return self == MoeRunnerBackend.FLASHINFER
|
||||
|
||||
def is_flashinfer_cutlass(self):
|
||||
return self == MoeRunnerBackend.FLASHINFER_CUTLASS
|
||||
|
||||
def is_flashinfer_mxfp4(self):
|
||||
return self == MoeRunnerBackend.FLASHINFER_MXFP4
|
||||
|
||||
|
||||
class DeepEPMode(Enum):
|
||||
|
||||
NORMAL = "normal"
|
||||
LOW_LATENCY = "low_latency"
|
||||
AUTO = "auto"
|
||||
|
||||
def enable_normal(self):
|
||||
def enable_normal(self) -> bool:
|
||||
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
|
||||
|
||||
def enable_low_latency(self):
|
||||
def enable_low_latency(self) -> bool:
|
||||
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
|
||||
|
||||
def resolve(self, is_extend_in_batch: bool):
|
||||
def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
|
||||
if self != DeepEPMode.AUTO:
|
||||
return self
|
||||
|
||||
@@ -57,3 +82,96 @@ class DeepEPMode(Enum):
|
||||
return DeepEPMode.NORMAL
|
||||
else:
|
||||
return DeepEPMode.LOW_LATENCY
|
||||
|
||||
def is_normal(self) -> bool:
|
||||
return self == DeepEPMode.NORMAL
|
||||
|
||||
def is_low_latency(self) -> bool:
|
||||
return self == DeepEPMode.LOW_LATENCY
|
||||
|
||||
def is_auto(self) -> bool:
|
||||
return self == DeepEPMode.AUTO
|
||||
|
||||
|
||||
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
|
||||
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
|
||||
DEEPEP_MODE: Optional[DeepEPMode] = None
|
||||
IS_TBO_ENABLED: Optional[bool] = None
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
|
||||
DEEPEP_CONFIG: Optional[str] = None
|
||||
|
||||
|
||||
def initialize_moe_config(server_args: ServerArgs):
|
||||
global MOE_A2A_BACKEND
|
||||
global MOE_RUNNER_BACKEND
|
||||
global DEEPEP_MODE
|
||||
global DEEPEP_CONFIG
|
||||
global IS_TBO_ENABLED
|
||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||
|
||||
MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
|
||||
MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
|
||||
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
|
||||
DEEPEP_CONFIG = server_args.deepep_config or ""
|
||||
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
|
||||
|
||||
|
||||
def get_moe_a2a_backend() -> MoeA2ABackend:
|
||||
global MOE_A2A_BACKEND
|
||||
if MOE_A2A_BACKEND is None:
|
||||
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
|
||||
MOE_A2A_BACKEND = MoeA2ABackend(None)
|
||||
return MOE_A2A_BACKEND
|
||||
|
||||
|
||||
def get_moe_runner_backend() -> MoeRunnerBackend:
|
||||
global MOE_RUNNER_BACKEND
|
||||
if MOE_RUNNER_BACKEND is None:
|
||||
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
|
||||
MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
|
||||
return MOE_RUNNER_BACKEND
|
||||
|
||||
|
||||
def get_deepep_mode() -> DeepEPMode:
|
||||
global DEEPEP_MODE
|
||||
if DEEPEP_MODE is None:
|
||||
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
|
||||
DEEPEP_MODE = DeepEPMode("auto")
|
||||
return DEEPEP_MODE
|
||||
|
||||
|
||||
def get_deepep_config() -> str:
|
||||
global DEEPEP_CONFIG
|
||||
if DEEPEP_CONFIG is None:
|
||||
logger.warning("DEEPEP_CONFIG is not initialized, using default config")
|
||||
DEEPEP_CONFIG = ""
|
||||
return DEEPEP_CONFIG
|
||||
|
||||
|
||||
def is_tbo_enabled() -> bool:
|
||||
global IS_TBO_ENABLED
|
||||
if IS_TBO_ENABLED is None:
|
||||
logger.warning("IS_TBO_ENABLED is not initialized, using False")
|
||||
IS_TBO_ENABLED = False
|
||||
return IS_TBO_ENABLED
|
||||
|
||||
|
||||
def get_tbo_token_distribution_threshold() -> float:
|
||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
|
||||
logger.warning(
|
||||
"TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
|
||||
)
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
|
||||
return TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def should_use_flashinfer_trtllm_moe():
|
||||
result = get_moe_runner_backend().is_flashinfer_trtllm() and (
|
||||
not importlib.util.find_spec("flashinfer")
|
||||
or pkg_version.parse(__import__("flashinfer").__version__)
|
||||
>= pkg_version.parse("0.2.9rc1")
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
@@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
**kwargs,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
|
||||
@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=(layer.w13_weight_scale_inv),
|
||||
w2_scale=(layer.w2_weight_scale_inv),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||
CompressedTensorsConfig,
|
||||
@@ -269,12 +270,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
|
||||
|
||||
@@ -283,8 +279,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=self.weight_quant.strategy
|
||||
== QuantizationStrategy.CHANNEL,
|
||||
@@ -292,8 +287,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
@@ -601,12 +594,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
**kwargs,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids, router_logits = topk_output
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -220,22 +221,10 @@ class MxFp4LinearMethod(LinearMethodBase):
|
||||
return out
|
||||
|
||||
|
||||
class MxFp4MoEMethod:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
class MxFp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: Mxfp4Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
@@ -364,12 +353,7 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
@@ -383,7 +367,9 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
activation=(
|
||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||
ActivationType.Silu
|
||||
if moe_runner_config.activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
doweight_stage1=False,
|
||||
)
|
||||
@@ -497,12 +483,7 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
@@ -516,7 +497,9 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
activation=(
|
||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||
ActivationType.Silu
|
||||
if moe_runner_config.activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
doweight_stage1=False,
|
||||
)
|
||||
|
||||
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
|
||||
@@ -982,12 +983,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
@@ -996,7 +992,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
apply_router_weight_on_input, topk_weights, x
|
||||
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
@@ -1021,8 +1017,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer,
|
||||
x,
|
||||
topk_output,
|
||||
activation,
|
||||
no_combine,
|
||||
moe_runner_config.activation,
|
||||
moe_runner_config.no_combine,
|
||||
)
|
||||
if ret is not None:
|
||||
return ret
|
||||
@@ -1060,8 +1056,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
use_fp8_blockscale=True,
|
||||
)
|
||||
# TODO: Fuse into select_experts
|
||||
if routed_scaling_factor is not None:
|
||||
output *= routed_scaling_factor
|
||||
if moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
# Expert fusion with FP8 quantization
|
||||
return fused_experts(
|
||||
@@ -1069,9 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
@@ -1084,26 +1078,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply_with_router_logits(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
activation = moe_runner_config.activation
|
||||
routed_scaling_factor = moe_runner_config.routed_scaling_factor
|
||||
|
||||
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
|
||||
|
||||
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||
|
||||
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||
router_logits = topk_output.router_logits
|
||||
topk_config = topk_output.topk_config
|
||||
assert (
|
||||
activation == "silu"
|
||||
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
||||
a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
|
||||
|
||||
return trtllm_fp8_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
@@ -1115,9 +1115,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
gemm2_weights=layer.w2_weight,
|
||||
gemm2_weights_scale=layer.w2_weight_scale_inv,
|
||||
num_experts=layer.num_experts,
|
||||
top_k=layer.top_k,
|
||||
n_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
top_k=topk_config.top_k,
|
||||
n_group=topk_config.num_expert_group,
|
||||
topk_group=topk_config.topk_group,
|
||||
intermediate_size=layer.w2_weight.shape[2],
|
||||
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
|
||||
local_num_experts=layer.num_local_experts,
|
||||
|
||||
@@ -113,6 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
|
||||
return weight, weight_scale, input_scale
|
||||
|
||||
|
||||
# TODO(ch-wan): define these backends in --moe-runner-backend
|
||||
def cutlass_block_fp8_supported() -> bool:
|
||||
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
|
||||
return False
|
||||
|
||||
@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
from sglang.srt.utils import is_cuda
|
||||
@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
**kwargs,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
# Delay the import to avoid circular dependency
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
|
||||
@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
||||
)[0]
|
||||
|
||||
|
||||
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
||||
def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
|
||||
hidden_size = layer.hidden_size
|
||||
intermediate_size_per_partition = layer.intermediate_size_per_partition
|
||||
# apply_router_weight_on_input is not supported for moe marlin
|
||||
supports_router_weight = not layer.apply_router_weight_on_input
|
||||
supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
|
||||
# moe marlin requires the activation to be silu
|
||||
supports_activation = layer.activation == "silu"
|
||||
supports_activation = layer.moe_runner_config.activation == "silu"
|
||||
|
||||
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
|
||||
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
|
||||
|
||||
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import is_cuda, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
if is_cuda():
|
||||
@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False, # ModelOpt uses per-tensor quantization
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
|
||||
@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@property
|
||||
def enable_flashinfer_cutlass_moe(self) -> bool:
|
||||
from sglang.srt.layers.moe import get_moe_runner_backend
|
||||
|
||||
"""Access the global enable_flashinfer_cutlass_moe setting."""
|
||||
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
|
||||
return get_moe_runner_backend().is_flashinfer_cutlass()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
ep_rank: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
|
||||
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
|
||||
@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
assert (
|
||||
not apply_router_weight_on_input
|
||||
not moe_runner_config.apply_router_weight_on_input
|
||||
), "apply_router_weight_on_input is not supported for Flashinfer"
|
||||
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_blockscale_swizzled.view(torch.int32),
|
||||
layer.g2_alphas,
|
||||
],
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
ep_size=layer.moe_ep_size,
|
||||
ep_rank=layer.moe_ep_rank,
|
||||
tp_size=layer.moe_tp_size,
|
||||
tp_rank=layer.moe_tp_rank,
|
||||
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
||||
)[0]
|
||||
if routed_scaling_factor is not None:
|
||||
output *= routed_scaling_factor
|
||||
if moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
params=layer.cutlass_moe_params,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
|
||||
).to(x.dtype)
|
||||
if routed_scaling_factor is not None:
|
||||
output *= routed_scaling_factor
|
||||
if moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
|
||||
@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
|
||||
@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
# avoid circular import
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
w1_scale=layer.w13_scales,
|
||||
@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size],
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
if "w13_qzeros" in weight_name:
|
||||
tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
|
||||
tp_rank
|
||||
]
|
||||
tensor = loaded_weight.view(
|
||||
layer.moe_tp_size, -1, loaded_weight.size(1)
|
||||
)[tp_rank]
|
||||
if shard_id == "w1":
|
||||
param.data[expert_id, : shard_size // 2] = tensor
|
||||
else:
|
||||
param.data[expert_id, shard_size // 2 :] = tensor
|
||||
elif "w2_qzeros" in weight_name:
|
||||
param.data[expert_id] = loaded_weight.view(
|
||||
loaded_weight.size(0), layer.tp_size, -1
|
||||
loaded_weight.size(0), layer.moe_tp_size, -1
|
||||
)[:, tp_rank]
|
||||
else:
|
||||
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
|
||||
|
||||
@@ -16,14 +16,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.moe.utils import get_moe_runner_backend
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
QuantizationConfig,
|
||||
@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_bool_env_var,
|
||||
@@ -60,6 +58,7 @@ if is_flashinfer_available():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
prefix: str,
|
||||
):
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.prefix = prefix
|
||||
self.topk_indices_dtype = None
|
||||
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
||||
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||
self.with_bias = False
|
||||
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
|
||||
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||
|
||||
self.triton_kernel_moe_forward = None
|
||||
self.triton_kernel_moe_with_bias_forward = None
|
||||
@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
logger,
|
||||
f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
|
||||
)
|
||||
# TODO: these values are hardcoded for now, we need to get them from the model
|
||||
layer.gemm1_alpha = Parameter(
|
||||
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False,
|
||||
@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
if self.use_flashinfer:
|
||||
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
||||
@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
b1=layer.w13_weight_bias,
|
||||
b2=layer.w2_weight_bias,
|
||||
topk_output=topk_output,
|
||||
activation=activation,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
else:
|
||||
return self.triton_kernel_moe_forward(
|
||||
@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
else:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
b1=layer.w13_weight_bias,
|
||||
b2=layer.w2_weight_bias,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
import importlib.util
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -24,7 +24,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
@@ -221,31 +221,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
kwargs = {}
|
||||
if activation_alpha is not None:
|
||||
kwargs["activation_alpha"] = activation_alpha
|
||||
if swiglu_limit is not None:
|
||||
kwargs["swiglu_limit"] = swiglu_limit
|
||||
|
||||
return self.forward(
|
||||
x=x,
|
||||
layer=layer,
|
||||
topk_output=topk_output,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
**kwargs,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
@@ -253,18 +236,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.use_triton_kernels:
|
||||
if self.with_bias:
|
||||
assert self.triton_kernel_moe_with_bias_forward is not None
|
||||
return self.triton_kernel_moe_with_bias_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -272,24 +249,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
b1=layer.w13_weight_bias,
|
||||
b2=layer.w2_weight_bias,
|
||||
topk_output=topk_output,
|
||||
activation=activation,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
moe_runner_config=moe_runner_config,
|
||||
w1_pcg=None,
|
||||
w2_pcg=None,
|
||||
)
|
||||
else:
|
||||
assert self.triton_kernel_moe_forward is not None
|
||||
return self.triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
else:
|
||||
if _use_aiter:
|
||||
assert not no_combine, "unsupported"
|
||||
assert not moe_runner_config.no_combine, "unsupported"
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
if apply_router_weight_on_input:
|
||||
if moe_runner_config.apply_router_weight_on_input:
|
||||
assert (
|
||||
topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
@@ -309,7 +286,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_ids,
|
||||
activation=(
|
||||
ActivationType.Silu
|
||||
if activation == "silu"
|
||||
if moe_runner_config.activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
@@ -325,13 +302,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
b1=getattr(layer, "w13_weight_bias", None),
|
||||
b2=getattr(layer, "w2_weight_bias", None),
|
||||
topk_output=topk_output,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
@@ -339,21 +310,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", f"activation = {activation} is not supported."
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
), f"activation = {moe_runner_config.activation} is not supported."
|
||||
|
||||
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
|
||||
if (
|
||||
use_intel_amx_backend(layer)
|
||||
and not moe_runner_config.apply_router_weight_on_input
|
||||
):
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
apply_router_weight_on_input, topk_weights, x
|
||||
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
@@ -378,11 +349,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer,
|
||||
x,
|
||||
topk_output,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
moe_runner_config,
|
||||
)
|
||||
|
||||
def forward_npu(
|
||||
@@ -390,12 +357,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||
|
||||
@@ -403,11 +365,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer,
|
||||
x,
|
||||
topk_output,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
moe_runner_config,
|
||||
)
|
||||
|
||||
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
||||
|
||||
@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: EPMoE,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
**kwargs,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# TODO(ch-wan): move it out of this class
|
||||
@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_input_scale,
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
if routed_scaling_factor is not None:
|
||||
output *= routed_scaling_factor
|
||||
if moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
|
||||
@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
|
||||
@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=activation,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=True,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
@@ -49,6 +49,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
apply_router_weight_on_input, topk_weights, x
|
||||
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_int8_w8a8=True,
|
||||
per_channel_quant=True,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
||||
layer,
|
||||
x,
|
||||
topk_output: TopKOutput,
|
||||
**kwargs,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
||||
ScheduleBatchDisaggregationDecodeMixin,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.moe import is_tbo_enabled
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
BaseTokenToKVPoolAllocator,
|
||||
SWATokenToKVPoolAllocator,
|
||||
@@ -84,17 +85,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"device",
|
||||
"disable_chunked_prefix_cache",
|
||||
"disable_radix_cache",
|
||||
"enable_two_batch_overlap",
|
||||
"tbo_token_distribution_threshold",
|
||||
"enable_dp_lm_head",
|
||||
"moe_a2a_backend",
|
||||
"deepep_mode",
|
||||
"enable_flashinfer_cutlass_moe",
|
||||
"enable_flashinfer_trtllm_moe",
|
||||
"enable_flashinfer_allreduce_fusion",
|
||||
"moe_dense_tp_size",
|
||||
"ep_dispatch_algorithm",
|
||||
"deepep_config",
|
||||
"ep_num_redundant_experts",
|
||||
"enable_nan_detection",
|
||||
"flashinfer_mla_disable_ragged",
|
||||
@@ -107,8 +101,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"triton_attention_reduce_in_fp32",
|
||||
"num_reserved_decode_tokens",
|
||||
"weight_loader_disable_mmap",
|
||||
"enable_triton_kernel_moe",
|
||||
"enable_flashinfer_mxfp4_moe",
|
||||
"enable_multimodal",
|
||||
"enable_symm_mem",
|
||||
"quantization",
|
||||
|
||||
@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
|
||||
from sglang.srt.layers.moe import initialize_moe_config
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
CloseSessionReqInput,
|
||||
@@ -245,6 +245,9 @@ class Scheduler(
|
||||
)
|
||||
)
|
||||
|
||||
# Init model config
|
||||
self.model_config = ModelConfig.from_server_args(server_args)
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
self.idle_sleeper = None
|
||||
@@ -292,6 +295,9 @@ class Scheduler(
|
||||
# Init tokenizer
|
||||
self.init_tokenizer()
|
||||
|
||||
# Init moe config
|
||||
self.init_moe_config()
|
||||
|
||||
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled
|
||||
if self.server_args.reasoning_parser and self.tokenizer:
|
||||
reasoning_parser = ReasoningParser(
|
||||
@@ -538,8 +544,6 @@ class Scheduler(
|
||||
|
||||
def init_tokenizer(self):
|
||||
server_args = self.server_args
|
||||
|
||||
self.model_config = ModelConfig.from_server_args(server_args)
|
||||
self.is_generation = self.model_config.is_generation
|
||||
|
||||
if server_args.skip_tokenizer_init:
|
||||
@@ -761,6 +765,10 @@ class Scheduler(
|
||||
# The prefill requests that are in the middle of kv sending
|
||||
self.disagg_prefill_inflight_queue: List[Req] = []
|
||||
|
||||
def init_moe_config(self):
|
||||
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
|
||||
initialize_moe_config(self.server_args)
|
||||
|
||||
@DynamicGradMode()
|
||||
def event_loop_normal(self):
|
||||
"""A normal scheduler loop."""
|
||||
@@ -1823,11 +1831,6 @@ class Scheduler(
|
||||
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
||||
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
|
||||
enable_deepep_moe=MoeA2ABackend(
|
||||
self.server_args.moe_a2a_backend
|
||||
).is_deepep(),
|
||||
deepep_mode=DeepEPMode(self.server_args.deepep_mode),
|
||||
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
|
||||
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
||||
)
|
||||
@@ -1922,9 +1925,6 @@ class Scheduler(
|
||||
disable_cuda_graph: bool,
|
||||
spec_algorithm,
|
||||
speculative_num_draft_tokens,
|
||||
enable_two_batch_overlap: bool,
|
||||
enable_deepep_moe: bool,
|
||||
deepep_mode: DeepEPMode,
|
||||
require_mlp_tp_gather: bool,
|
||||
disable_overlap_schedule: bool,
|
||||
):
|
||||
@@ -1972,9 +1972,6 @@ class Scheduler(
|
||||
is_extend_in_batch,
|
||||
*tbo_preparer.prepare_all_gather(
|
||||
local_batch,
|
||||
deepep_mode,
|
||||
enable_deepep_moe,
|
||||
enable_two_batch_overlap,
|
||||
),
|
||||
],
|
||||
dtype=torch.int64,
|
||||
|
||||
@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
|
||||
initialize_dp_attention,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
|
||||
from sglang.srt.layers.quantization import (
|
||||
deep_gemm_wrapper,
|
||||
monkey_patch_isinstance_for_vllm_base_layer,
|
||||
@@ -219,8 +218,6 @@ class ModelRunner:
|
||||
# TODO it is indeed not a "server args"
|
||||
"use_mla_backend": self.use_mla_backend,
|
||||
"speculative_algorithm": self.spec_algorithm,
|
||||
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
|
||||
"deepep_mode": DeepEPMode(server_args.deepep_mode),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
self.router = DbrxRouter(config, self.params_dtype)
|
||||
self.topk = TopK(
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
)
|
||||
self.moe_runner_config = MoeRunnerConfig(inplace=True)
|
||||
self.ws = nn.Parameter(
|
||||
torch.empty(
|
||||
self.num_total_experts,
|
||||
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, self.d_model)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.router(hidden_states)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = fused_moe(
|
||||
hidden_states,
|
||||
self.ws,
|
||||
self.w2s,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
topk_output,
|
||||
self.moe_runner_config,
|
||||
)
|
||||
|
||||
if self.tp_size > 1:
|
||||
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm_1(hidden_states)
|
||||
x = self.attn(
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
|
||||
w1=self.w1,
|
||||
w2=self.w2,
|
||||
topk_output=topk_output,
|
||||
inplace=True,
|
||||
moe_runner_config=MoeRunnerConfig(inplace=True),
|
||||
)
|
||||
|
||||
if self.config.n_shared_experts is not None:
|
||||
|
||||
@@ -50,7 +50,6 @@ from sglang.srt.layers.communicator import (
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -61,9 +60,10 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
@@ -336,30 +336,6 @@ class DeepseekV2MoE(nn.Module):
|
||||
quant_config=quant_config,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
**(
|
||||
dict(deepep_mode=global_server_args_dict["deepep_mode"])
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
else {}
|
||||
),
|
||||
# Additional args for FusedMoE
|
||||
**(
|
||||
dict(
|
||||
enable_flashinfer_cutlass_moe=True,
|
||||
)
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
dict(
|
||||
renormalize=config.norm_topk_prob,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
)
|
||||
if should_use_flashinfer_trtllm_moe()
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
self.shared_experts_is_int8 = False
|
||||
@@ -377,7 +353,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
prefix=add_prefix("shared_experts", prefix),
|
||||
**(
|
||||
dict(tp_rank=0, tp_size=1)
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
if get_moe_a2a_backend().is_deepep()
|
||||
else {}
|
||||
),
|
||||
)
|
||||
@@ -407,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
if get_moe_a2a_backend().is_deepep():
|
||||
# TODO: we will support tp < ep in the future
|
||||
self.ep_size = get_moe_expert_parallel_world_size()
|
||||
self.num_experts = (
|
||||
@@ -431,12 +407,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
deepep_mode=global_server_args_dict["deepep_mode"],
|
||||
deepep_mode=get_deepep_mode(),
|
||||
async_finish=True,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
|
||||
|
||||
def get_moe_weights(self):
|
||||
return [
|
||||
@@ -484,13 +460,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
|
||||
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
|
||||
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
|
||||
if should_use_flashinfer_trtllm_moe():
|
||||
kwargs["topk_output"] = (self.topk, router_logits)
|
||||
else:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
if not _is_cuda:
|
||||
@@ -520,13 +490,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
|
||||
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
|
||||
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
|
||||
if should_use_flashinfer_trtllm_moe():
|
||||
kwargs["topk_output"] = (self.topk, router_logits)
|
||||
else:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
if not _is_cuda and not _use_aiter:
|
||||
@@ -2478,17 +2442,15 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
|
||||
)
|
||||
if self.quant_config and self.quant_config.get_name() == "w4afp8":
|
||||
expert_params_mapping += (
|
||||
get_moe_impl_class().make_expert_input_scale_params_mapping(
|
||||
num_experts=self.config.n_routed_experts
|
||||
)
|
||||
expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
|
||||
num_experts=self.config.n_routed_experts
|
||||
)
|
||||
|
||||
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
||||
|
||||
@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
|
||||
@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
|
||||
|
||||
class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
|
||||
@@ -39,7 +39,6 @@ from sglang.srt.layers.communicator import (
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -51,9 +50,10 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
@@ -76,10 +76,7 @@ from sglang.srt.models.deepseek_v2 import (
|
||||
DeepseekV2Model,
|
||||
DeepseekV2MoE,
|
||||
)
|
||||
from sglang.srt.two_batch_overlap import (
|
||||
MaybeTboDeepEPDispatcher,
|
||||
model_forward_maybe_tbo,
|
||||
)
|
||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||
from sglang.srt.utils import (
|
||||
BumpAllocator,
|
||||
LazyValue,
|
||||
@@ -414,19 +411,15 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
||||
)
|
||||
|
||||
self.topk = (
|
||||
TopK(
|
||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||
renormalize=config.norm_topk_prob,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
if not should_use_flashinfer_trtllm_moe()
|
||||
else None
|
||||
self.topk = TopK(
|
||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||
renormalize=config.norm_topk_prob,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
|
||||
self.experts = get_moe_impl_class()(
|
||||
@@ -441,31 +434,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
quant_config=quant_config,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
**(
|
||||
dict(deepep_mode=global_server_args_dict["deepep_mode"])
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
else {}
|
||||
),
|
||||
# Additional args for FusedMoE
|
||||
**(
|
||||
dict(
|
||||
enable_flashinfer_cutlass_moe=True,
|
||||
)
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
dict(
|
||||
renormalize=config.norm_topk_prob,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
)
|
||||
if should_use_flashinfer_trtllm_moe()
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
self.shared_experts_is_int8 = False
|
||||
@@ -496,7 +464,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
if get_moe_a2a_backend().is_deepep():
|
||||
# TODO: we will support tp < ep in the future
|
||||
self.ep_size = get_moe_expert_parallel_world_size()
|
||||
self.num_experts = (
|
||||
@@ -520,12 +488,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
deepep_mode=global_server_args_dict["deepep_mode"],
|
||||
deepep_mode=get_deepep_mode(),
|
||||
async_finish=True,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
|
||||
|
||||
def forward_normal_dual_stream(
|
||||
self,
|
||||
@@ -542,10 +510,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
if self.topk is not None:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
kwargs["router_logits"] = router_logits
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
if not _is_cuda:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
@@ -588,10 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
if self.topk is not None:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
kwargs["router_logits"] = router_logits
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
if not _is_cuda and not _use_aiter:
|
||||
# fused in biased_grouped_topk so we can skip here
|
||||
@@ -761,8 +723,6 @@ class Glm4MoeModel(DeepseekV2Model):
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
|
||||
|
||||
class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
|
||||
|
||||
@@ -789,7 +749,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
|
||||
self._routed_experts_weights_of_layer = LazyValue(
|
||||
lambda: {
|
||||
@@ -953,7 +912,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
|
||||
@@ -8,19 +8,11 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
parallel_state,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
@@ -49,7 +41,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
|
||||
config.moe_layer_freq = 1
|
||||
self.config = config
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
self.quant_config = quant_config
|
||||
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
|
||||
self.num_fused_shared_experts = (
|
||||
@@ -232,7 +223,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
|
||||
@@ -40,7 +40,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -50,9 +49,10 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
@@ -110,16 +110,13 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.layer_id = layer_id
|
||||
self.activation = config.hidden_act
|
||||
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
|
||||
self.swiglu_limit = config.swiglu_limit
|
||||
self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
|
||||
self.gemm1_clamp_limit = config.swiglu_limit
|
||||
|
||||
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
|
||||
self.topk = None
|
||||
else:
|
||||
self.topk = TopK(
|
||||
top_k=config.num_experts_per_tok,
|
||||
renormalize=True,
|
||||
)
|
||||
self.topk = TopK(
|
||||
top_k=config.num_experts_per_tok,
|
||||
renormalize=True,
|
||||
)
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
experts_type = get_moe_impl_class()
|
||||
@@ -129,11 +126,9 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
quant_config.get_name() if quant_config is not None else None
|
||||
)
|
||||
extra_kwargs = {
|
||||
"enable_flashinfer_cutlass_moe": global_server_args_dict[
|
||||
"enable_flashinfer_cutlass_moe"
|
||||
],
|
||||
# for moe gate_up_proj and down_proj and their bias loading
|
||||
"use_weight_loader_fused": quant_config_name != "mxfp4",
|
||||
"use_weight_loader_fused": quant_config_name
|
||||
!= "mxfp4"
|
||||
}
|
||||
self.experts = experts_type(
|
||||
num_experts=config.num_local_experts
|
||||
@@ -144,15 +139,10 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
activation=self.activation,
|
||||
activation_alpha=self.activation_alpha,
|
||||
swiglu_limit=self.swiglu_limit,
|
||||
gemm1_alpha=self.gemm1_alpha,
|
||||
gemm1_clamp_limit=self.gemm1_clamp_limit,
|
||||
with_bias=True,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
**(
|
||||
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
else {}
|
||||
),
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
@@ -171,7 +161,7 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
forward_batch: Optional[ForwardBatch] = None,
|
||||
should_allreduce_fusion: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
if not get_moe_a2a_backend().is_deepep():
|
||||
return self.forward_normal(hidden_states, should_allreduce_fusion)
|
||||
else:
|
||||
raise Exception("forward_deepep branch not implemented yet")
|
||||
@@ -189,17 +179,10 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
should_allreduce_fusion: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
if self.topk is not None:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
kwargs["topk_output"] = (self.top_k, router_logits)
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||
|
||||
if self.tp_size > 1 and not should_allreduce_fusion:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
@@ -436,7 +419,6 @@ class GptOssDecoderLayer(nn.Module):
|
||||
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
|
||||
# GptOss all layers are sparse and have no nextn now
|
||||
self.is_layer_sparse = True
|
||||
@@ -1060,7 +1042,7 @@ class GptOssForCausalLM(nn.Module):
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
|
||||
ckpt_gate_up_proj_name="gate_up_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
||||
|
||||
@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
|
||||
@@ -135,7 +135,6 @@ class Grok1MoE(nn.Module):
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
activation="gelu",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
@@ -254,7 +255,7 @@ class InternS1ForConditionalGeneration(nn.Module):
|
||||
]
|
||||
expert_params_mapping = []
|
||||
if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
|
||||
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
@@ -616,7 +616,7 @@ class InternVLChatModel(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
|
||||
@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
|
||||
rope_theta = config.rope_theta
|
||||
rope_scaling = config.rope_scaling
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, is_cuda
|
||||
|
||||
@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
|
||||
|
||||
@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
|
||||
intermediate_size=intermediate_size,
|
||||
reduce_results=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
|
||||
@@ -17,8 +17,6 @@
|
||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.eplb.expert_distribution import (
|
||||
ExpertDistributionRecorder,
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.communicator import (
|
||||
@@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import (
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -55,8 +49,8 @@ from sglang.srt.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
# Additional args for FusedMoE
|
||||
**(
|
||||
dict(
|
||||
enable_flashinfer_cutlass_moe=True,
|
||||
)
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
@@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
|
||||
# Qwen2MoE all layers are sparse and have no nextn now
|
||||
self.is_layer_sparse = True
|
||||
|
||||
@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
parallel_state,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
|
||||
|
||||
Qwen3MoeConfig = None
|
||||
@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
**(
|
||||
dict(deepep_mode=global_server_args_dict["deepep_mode"])
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
else {}
|
||||
),
|
||||
# Additional args for FusedMoE
|
||||
**(
|
||||
dict(
|
||||
enable_flashinfer_cutlass_moe=True,
|
||||
)
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
if get_moe_a2a_backend().is_deepep():
|
||||
# TODO: we will support tp < ep in the future
|
||||
self.ep_size = get_moe_expert_parallel_world_size()
|
||||
self.num_experts = (
|
||||
@@ -150,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
use_reduce_scatter: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
if not get_moe_a2a_backend().is_deepep():
|
||||
return self.forward_normal(hidden_states, use_reduce_scatter)
|
||||
else:
|
||||
return self.forward_deepep(hidden_states, forward_batch)
|
||||
@@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
|
||||
# Qwen3MoE all layers are sparse and have no nextn now
|
||||
self.is_layer_sparse = True
|
||||
@@ -778,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
|
||||
@@ -38,6 +38,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
@@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
if get_moe_a2a_backend().is_deepep():
|
||||
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
|
||||
]
|
||||
)
|
||||
self.pack_params()
|
||||
self.moe_runner_config = MoeRunnerConfig(inplace=True)
|
||||
|
||||
self.router = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
|
||||
quant_config=None,
|
||||
prefix=add_prefix("router", prefix),
|
||||
)
|
||||
self.topk = TopK(
|
||||
top_k=self.top_k,
|
||||
renormalize=getattr(self.config, "norm_topk_prob", False),
|
||||
)
|
||||
|
||||
if config.num_shared_experts is not None:
|
||||
intermediate_size = config.intermediate_size * config.num_shared_experts
|
||||
@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = fused_moe(
|
||||
hidden_states,
|
||||
self.w1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=getattr(self.config, "norm_topk_prob", False),
|
||||
inplace=True,
|
||||
topk_output,
|
||||
self.moe_runner_config,
|
||||
)
|
||||
|
||||
if self.config.num_shared_experts is not None:
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
is_port_available,
|
||||
is_remote_url,
|
||||
is_triton_kernels_available,
|
||||
is_valid_ipv6_address,
|
||||
nullable_str,
|
||||
)
|
||||
@@ -175,9 +176,15 @@ class ServerArgs:
|
||||
|
||||
# Expert parallelism
|
||||
ep_size: int = 1
|
||||
moe_a2a_backend: Optional[Literal["deepep"]] = None
|
||||
enable_flashinfer_cutlass_moe: bool = False
|
||||
enable_flashinfer_trtllm_moe: bool = False
|
||||
moe_a2a_backend: Literal["none", "deepep"] = "none"
|
||||
moe_runner_backend: Literal[
|
||||
"auto",
|
||||
"triton",
|
||||
"triton_kernel",
|
||||
"flashinfer_trtllm",
|
||||
"flashinfer_cutlass",
|
||||
"flashinfer_mxfp4",
|
||||
] = "auto"
|
||||
enable_flashinfer_allreduce_fusion: bool = False
|
||||
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
|
||||
ep_num_redundant_experts: int = 0
|
||||
@@ -250,8 +257,6 @@ class ServerArgs:
|
||||
disable_chunked_prefix_cache: bool = False
|
||||
disable_fast_image_processor: bool = False
|
||||
enable_return_hidden_states: bool = False
|
||||
enable_triton_kernel_moe: bool = False
|
||||
enable_flashinfer_mxfp4_moe: bool = False
|
||||
scheduler_recv_interval: int = 1
|
||||
|
||||
# Debug tensor dumps
|
||||
@@ -282,6 +287,9 @@ class ServerArgs:
|
||||
# Deprecated arguments
|
||||
enable_ep_moe: bool = False
|
||||
enable_deepep_moe: bool = False
|
||||
enable_flashinfer_cutlass_moe: bool = False
|
||||
enable_flashinfer_trtllm_moe: bool = False
|
||||
enable_triton_kernel_moe: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Check deprecated arguments
|
||||
@@ -298,6 +306,21 @@ class ServerArgs:
|
||||
print_deprecated_warning(
|
||||
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
|
||||
)
|
||||
if self.enable_triton_kernel_moe:
|
||||
self.moe_runner_backend = "triton_kernel"
|
||||
print_deprecated_warning(
|
||||
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
|
||||
)
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
self.moe_runner_backend = "flashinfer_cutlass"
|
||||
print_deprecated_warning(
|
||||
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
|
||||
)
|
||||
if self.enable_flashinfer_trtllm_moe:
|
||||
self.moe_runner_backend = "flashinfer_trtllm"
|
||||
print_deprecated_warning(
|
||||
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
|
||||
)
|
||||
|
||||
# Set missing default values
|
||||
if self.tokenizer_path is None:
|
||||
@@ -517,7 +540,7 @@ class ServerArgs:
|
||||
), "Please enable dp attention when setting enable_dp_lm_head. "
|
||||
|
||||
# MoE kernel
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
if self.moe_runner_backend == "flashinfer_cutlass":
|
||||
assert (
|
||||
self.quantization == "modelopt_fp4"
|
||||
), "modelopt_fp4 quantization is required for Flashinfer MOE"
|
||||
@@ -527,7 +550,7 @@ class ServerArgs:
|
||||
self.tp_size,
|
||||
], "The expert parallel size must be 1 or the same as the tensor parallel size"
|
||||
|
||||
if self.enable_flashinfer_trtllm_moe:
|
||||
if self.moe_runner_backend == "flashinfer_trtllm":
|
||||
if not self.disable_shared_experts_fusion:
|
||||
self.disable_shared_experts_fusion = True
|
||||
logger.warning(
|
||||
@@ -556,7 +579,7 @@ class ServerArgs:
|
||||
self.ep_dispatch_algorithm = "static"
|
||||
|
||||
if self.enable_eplb:
|
||||
assert self.ep_size > 1 or self.moe_a2a_backend is not None
|
||||
assert self.ep_size > 1
|
||||
|
||||
if self.enable_expert_distribution_metrics and (
|
||||
self.expert_distribution_recorder_mode is None
|
||||
@@ -1446,19 +1469,22 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--moe-a2a-backend",
|
||||
type=str,
|
||||
choices=["deepep"],
|
||||
choices=["none", "deepep"],
|
||||
default=ServerArgs.moe_a2a_backend,
|
||||
help="Choose the backend for MoE A2A.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-cutlass-moe",
|
||||
action="store_true",
|
||||
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-trtllm-moe",
|
||||
action="store_true",
|
||||
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
|
||||
"--moe-runner-backend",
|
||||
type=str,
|
||||
choices=[
|
||||
"auto",
|
||||
"triton",
|
||||
"triton_kernel",
|
||||
"flashinfer_trtllm",
|
||||
"flashinfer_cutlass",
|
||||
],
|
||||
default=ServerArgs.moe_runner_backend,
|
||||
help="Choose the runner backend for MoE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-allreduce-fusion",
|
||||
@@ -1825,11 +1851,6 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable returning hidden states with responses.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-triton-kernel-moe",
|
||||
action="store_true",
|
||||
help="Use triton moe grouped gemm kernel.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-mxfp4-moe",
|
||||
action="store_true",
|
||||
@@ -1965,6 +1986,21 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-cutlass-moe",
|
||||
action="store_true",
|
||||
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-trtllm-moe",
|
||||
action="store_true",
|
||||
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-triton-kernel-moe",
|
||||
action="store_true",
|
||||
help="(Deprecated) Use triton moe grouped gemm kernel.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
@@ -2143,18 +2179,21 @@ class ServerArgs:
|
||||
)
|
||||
|
||||
if is_sm100_supported() and is_mxfp4_quant_format:
|
||||
self.enable_flashinfer_mxfp4_moe = True
|
||||
self.enable_triton_kernel_moe = False
|
||||
self.moe_runner_backend = "flashinfer_mxfp4"
|
||||
logger.warning(
|
||||
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
||||
)
|
||||
else:
|
||||
if self.enable_triton_kernel_moe:
|
||||
if self.moe_runner_backend == "triton_kernel":
|
||||
assert (
|
||||
self.ep_size == 1
|
||||
), "Triton kernel MoE is only supported when ep_size == 1"
|
||||
if not self.enable_triton_kernel_moe and self.ep_size == 1:
|
||||
self.enable_triton_kernel_moe = True
|
||||
if (
|
||||
self.moe_runner_backend == "auto"
|
||||
and self.ep_size == 1
|
||||
and is_triton_kernels_available()
|
||||
):
|
||||
self.moe_runner_backend = "triton_kernel"
|
||||
logger.warning(
|
||||
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
||||
)
|
||||
|
||||
@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
|
||||
CommunicateSummableTensorPairFn,
|
||||
ScatterMode,
|
||||
)
|
||||
from sglang.srt.layers.moe import (
|
||||
get_deepep_mode,
|
||||
get_moe_a2a_backend,
|
||||
get_tbo_token_distribution_threshold,
|
||||
is_tbo_enabled,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
|
||||
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
|
||||
left_sum = sum(extend_lens[:vanilla_split_seq_index])
|
||||
overall_sum = sum(extend_lens)
|
||||
threshold = global_server_args_dict["tbo_token_distribution_threshold"]
|
||||
threshold = get_tbo_token_distribution_threshold()
|
||||
assert threshold <= 0.5, f"{threshold=}"
|
||||
return left_sum < overall_sum * threshold or left_sum > overall_sum * (
|
||||
1 - threshold
|
||||
@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
|
||||
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
|
||||
|
||||
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
||||
if not global_server_args_dict["enable_two_batch_overlap"]:
|
||||
if not is_tbo_enabled():
|
||||
return
|
||||
token_num_per_seq = get_token_num_per_seq(
|
||||
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
||||
@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
|
||||
def prepare_all_gather(
|
||||
self,
|
||||
local_batch: ScheduleBatch,
|
||||
deepep_mode: DeepEPMode,
|
||||
enable_deepep_moe: bool,
|
||||
enable_two_batch_overlap: bool,
|
||||
):
|
||||
|
||||
deepep_mode = get_deepep_mode()
|
||||
enable_deepep_moe = get_moe_a2a_backend().is_deepep()
|
||||
enable_two_batch_overlap = is_tbo_enabled()
|
||||
|
||||
self.enable_two_batch_overlap = enable_two_batch_overlap
|
||||
|
||||
if local_batch is not None:
|
||||
@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
|
||||
and not local_batch.forward_mode.is_target_verify()
|
||||
)
|
||||
and enable_deepep_moe
|
||||
and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
|
||||
and (resolved_deepep_mode.is_low_latency())
|
||||
)
|
||||
else:
|
||||
self.local_tbo_split_seq_index = 0
|
||||
@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
|
||||
"req_to_token_pool",
|
||||
"token_to_kv_pool",
|
||||
"can_run_dp_cuda_graph",
|
||||
"dp_padding_mode",
|
||||
"global_forward_mode",
|
||||
"spec_algorithm",
|
||||
"capture_hidden_mode",
|
||||
@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
|
||||
tbo_children=None,
|
||||
global_num_tokens_gpu=None,
|
||||
global_num_tokens_cpu=None,
|
||||
dp_padding_mode=None,
|
||||
global_dp_buffer_len=global_dp_buffer_len,
|
||||
global_num_tokens_for_logprob_gpu=None,
|
||||
global_num_tokens_for_logprob_cpu=None,
|
||||
@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
|
||||
|
||||
class MaybeTboDeepEPDispatcher:
|
||||
def __init__(self, **kwargs):
|
||||
num_inner_dispatchers = (
|
||||
2 if global_server_args_dict["enable_two_batch_overlap"] else 1
|
||||
)
|
||||
num_inner_dispatchers = 2 if is_tbo_enabled() else 1
|
||||
self._inners = [
|
||||
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
|
||||
]
|
||||
|
||||
@@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args):
|
||||
return True
|
||||
elif not server_args.enable_dp_lm_head:
|
||||
return True
|
||||
elif server_args.moe_a2a_backend is None:
|
||||
elif server_args.moe_a2a_backend == "none":
|
||||
return True
|
||||
else:
|
||||
return (
|
||||
@@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args):
|
||||
Check if the input of attention is scattered.
|
||||
"""
|
||||
assert server_args.moe_dense_tp_size in [1, None]
|
||||
if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
|
||||
if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
|
||||
if server_args.enable_dp_attention:
|
||||
return server_args.dp_size < server_args.tp_size
|
||||
else:
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_tensor_quant_mla_fp8,
|
||||
per_token_group_quant_fp8,
|
||||
@@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
with torch.inference_mode():
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a, w1, w2, w1_s, w2_s, score, topk, block_size
|
||||
)
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||
)
|
||||
out = fused_moe(
|
||||
a,
|
||||
@@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a, w1, w2, w1_s, w2_s, score, topk, block_size
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
||||
|
||||
@@ -12,7 +12,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
run_moe_ep_preproess,
|
||||
silu_and_mul_triton_kernel,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -22,35 +22,26 @@ def ep_moe(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
topk_config: TopKConfig,
|
||||
# ep config
|
||||
num_experts: int = 256,
|
||||
fp8_dtype: torch.types = torch.float8_e4m3fn,
|
||||
num_experts_per_partition: int = 128,
|
||||
start_expert_id: int = 0,
|
||||
end_expert_id: int = 127,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
w1_scale_inv: Optional[torch.Tensor] = None,
|
||||
w2_scale_inv: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
):
|
||||
use_blockwise_fp8 = block_shape is not None
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
top_k = topk_config.top_k
|
||||
topk_output = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
# correction_bias=correction_bias, #skip this in test
|
||||
custom_routing_function=custom_routing_function,
|
||||
topk_config=topk_config,
|
||||
)
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
|
||||
|
||||
@@ -294,14 +285,18 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
||||
start_id = cur_rank * num_experts_per_partition
|
||||
end_id = start_id + num_experts_per_partition - 1
|
||||
|
||||
topk_config = TopKConfig(
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
out = ep_moe(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
topk_config=topk_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale_inv=w1_s,
|
||||
w2_scale_inv=w2_s,
|
||||
@@ -316,8 +311,7 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
||||
w1=w1_ref,
|
||||
w2=w2_ref,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
topk_config=topk_config,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale_inv=None,
|
||||
w2_scale_inv=None,
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
|
||||
|
||||
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
|
||||
@@ -100,11 +100,12 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
|
||||
s_strides2 = c_strides2
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype, device=device)
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||
)
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
expert_map = torch.arange(E, dtype=torch.int32, device=device)
|
||||
expert_map[local_e:] = E
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from sgl_kernel import scaled_fp4_quant
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
|
||||
if torch.cuda.get_device_capability() < (10, 0):
|
||||
pytest.skip(
|
||||
@@ -163,11 +163,12 @@ def check_moe(
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||
)
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user