From 3fa62da78c120be7103bfcb6fd1405d3630d6c98 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Fri, 5 Sep 2025 21:09:09 -0700 Subject: [PATCH] [7/N] MoE Refactor: the implementation of new framework (#9269) --- python/sglang/srt/eplb/expert_distribution.py | 23 +- python/sglang/srt/eplb/expert_location.py | 11 +- python/sglang/srt/layers/moe/__init__.py | 3 +- .../sglang/srt/layers/moe/fused_moe_native.py | 8 +- .../layers/moe/fused_moe_triton/fused_moe.py | 7 +- .../srt/layers/moe/fused_moe_triton/layer.py | 77 +-- .../srt/layers/moe/moe_runner/__init__.py | 3 +- .../sglang/srt/layers/moe/moe_runner/base.py | 285 ++++++++++- .../srt/layers/moe/moe_runner/runner.py | 84 ++++ .../srt/layers/moe/moe_runner/triton.py | 442 ++++++++++++++++++ .../layers/moe/token_dispatcher/__init__.py | 18 +- .../{base_dispatcher.py => base.py} | 75 ++- .../srt/layers/moe/token_dispatcher/deepep.py | 31 +- .../layers/moe/token_dispatcher/standard.py | 46 +- python/sglang/srt/layers/moe/utils.py | 10 +- python/sglang/srt/layers/quantization/awq.py | 26 +- .../srt/layers/quantization/base_config.py | 17 +- .../srt/layers/quantization/blockwise_int8.py | 63 +-- .../compressed_tensors_moe.py | 80 ++-- python/sglang/srt/layers/quantization/fp8.py | 106 +++-- python/sglang/srt/layers/quantization/gptq.py | 42 +- .../srt/layers/quantization/modelopt_quant.py | 93 ++-- .../srt/layers/quantization/moe_wna16.py | 39 +- .../sglang/srt/layers/quantization/mxfp4.py | 104 +++-- .../layers/quantization/quark/quark_moe.py | 53 ++- .../sglang/srt/layers/quantization/unquant.py | 112 +++-- .../sglang/srt/layers/quantization/w4afp8.py | 43 +- .../srt/layers/quantization/w8a8_fp8.py | 55 ++- .../srt/layers/quantization/w8a8_int8.py | 102 ++-- python/sglang/srt/managers/schedule_batch.py | 1 - python/sglang/srt/model_loader/__init__.py | 12 +- python/sglang/srt/model_loader/loader.py | 22 +- python/sglang/test/test_cutlass_moe.py | 29 +- test/srt/test_mla_deepseek_v3.py | 37 ++ 34 files changed, 1727 insertions(+), 432 deletions(-) create mode 100644 python/sglang/srt/layers/moe/moe_runner/runner.py create mode 100644 python/sglang/srt/layers/moe/moe_runner/triton.py rename python/sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py => base.py} (52%) diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py index 1b3d573d8..e59337323 100644 --- a/python/sglang/srt/eplb/expert_distribution.py +++ b/python/sglang/srt/eplb/expert_distribution.py @@ -11,6 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + +from __future__ import annotations + import logging import math import os @@ -19,17 +22,19 @@ from abc import ABC from collections import deque from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type import einops import torch import torch.distributed -from sglang.srt.eplb.expert_location import ExpertLocationMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var +if TYPE_CHECKING: + from sglang.srt.eplb.expert_location import ExpertLocationMetadata + logger = logging.getLogger(__name__) # --------------------------------------- Entrypoint ----------------------------------------- @@ -43,7 +48,7 @@ class ExpertDistributionRecorder(ABC): @staticmethod def init_new( server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", + expert_location_metadata: ExpertLocationMetadata, rank: int, ): if server_args.expert_distribution_recorder_mode is not None: @@ -118,7 +123,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): def __init__( self, server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", + expert_location_metadata: ExpertLocationMetadata, rank: int, ): self._server_args = server_args @@ -279,7 +284,7 @@ class _SinglePassGatherer(ABC): @staticmethod def init_new( server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", + expert_location_metadata: ExpertLocationMetadata, rank: int, ) -> "_SinglePassGatherer": if server_args.expert_distribution_recorder_mode == "per_token": @@ -307,7 +312,7 @@ class _SinglePassGatherer(ABC): return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank) - def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): + def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int): self._expert_location_metadata = expert_location_metadata self._rank = rank @@ -346,7 +351,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer): def __init__( self, server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", + expert_location_metadata: ExpertLocationMetadata, rank: int, ): super().__init__(expert_location_metadata, rank) @@ -561,7 +566,7 @@ class _Accumulator(ABC): @staticmethod def init_new( server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", + expert_location_metadata: ExpertLocationMetadata, rank: int, ) -> "_Accumulator": return _Accumulator.get_class(server_args)( @@ -580,7 +585,7 @@ class _Accumulator(ABC): def __init__( self, server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", + expert_location_metadata: ExpertLocationMetadata, rank: int, ): self._server_args = server_args diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index be0e23653..ee5f2c7ca 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -11,21 +11,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + +from __future__ import annotations + import json import logging import random from dataclasses import dataclass from pathlib import Path -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional import torch import torch.distributed import torch.nn.functional as F -from sglang.srt.configs.model_config import ModelConfig from sglang.srt.eplb import eplb_algorithms from sglang.srt.model_loader import get_model_architecture -from sglang.srt.server_args import ServerArgs + +if TYPE_CHECKING: + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.server_args import ServerArgs logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/moe/__init__.py b/python/sglang/srt/layers/moe/__init__.py index e5e5930a2..5c75a3682 100644 --- a/python/sglang/srt/layers/moe/__init__.py +++ b/python/sglang/srt/layers/moe/__init__.py @@ -1,4 +1,4 @@ -from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig from sglang.srt.layers.moe.utils import ( DeepEPMode, MoeA2ABackend, @@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import ( __all__ = [ "DeepEPMode", "MoeA2ABackend", + "MoeRunner", "MoeRunnerConfig", "MoeRunnerBackend", "initialize_moe_config", diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 92b88b1b7..a3d3a09bf 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -8,16 +8,18 @@ from torch.nn import functional as F from sglang.srt.layers.activation import GeluAndMul, SiluAndMul from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput from sglang.srt.layers.moe.topk import StandardTopKOutput def fused_moe_forward_native( layer: torch.nn.Module, - x: torch.Tensor, - topk_output: StandardTopKOutput, - moe_runner_config: MoeRunnerConfig, + dispatch_output: StandardDispatchOutput, ) -> torch.Tensor: + x, topk_output = dispatch_output + moe_runner_config = layer.moe_runner_config + if moe_runner_config.apply_router_weight_on_input: raise NotImplementedError() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 4660df676..6d3fb53b0 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -1,3 +1,4 @@ +# NOTE: this file will be separated into sglang/srt/layers/moe/moe_runner/triton_utils.py # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py """Fused MoE kernel.""" @@ -6,13 +7,12 @@ from __future__ import annotations import functools import os -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional import torch import triton.language as tl from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig -from sglang.srt.layers.moe.topk import StandardTopKOutput from sglang.srt.utils import ( cpu_has_amx_support, direct_register_custom_op, @@ -26,6 +26,9 @@ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_c from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton from .moe_align_block_size import moe_align_block_size +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import StandardTopKOutput + _is_hip = is_hip() _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index b88c60d96..6e9a5f35c 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -23,8 +23,13 @@ from sglang.srt.layers.moe import ( get_moe_runner_backend, should_use_flashinfer_trtllm_moe, ) +from sglang.srt.layers.moe.token_dispatcher.standard import ( + CombineInput, + StandardDispatcher, +) from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -152,16 +157,6 @@ class FusedMoE(torch.nn.Module): self.expert_map_cpu = None self.expert_map_gpu = None - self.moe_runner_config = MoeRunnerConfig( - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - inplace=inplace, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - gemm1_alpha=gemm1_alpha, - gemm1_clamp_limit=gemm1_clamp_limit, - ) - enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass() if enable_flashinfer_cutlass_moe and quant_config is None: @@ -196,13 +191,6 @@ class FusedMoE(torch.nn.Module): self.use_presharded_weights = use_presharded_weights self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel() - if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( - self.use_triton_kernels - ) - else: - self.quant_method = quant_config.get_quant_method(self, prefix) - assert self.quant_method is not None self.quant_config = quant_config self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4() @@ -213,12 +201,40 @@ class FusedMoE(torch.nn.Module): and self.use_flashinfer_mxfp4_moe ): hidden_size = round_up(hidden_size, 256) + self.hidden_size = hidden_size + + self.moe_runner_config = MoeRunnerConfig( + num_experts=num_experts, + num_local_experts=self.num_local_experts, + hidden_size=hidden_size, + intermediate_size_per_partition=self.intermediate_size_per_partition, + layer_id=layer_id, + top_k=top_k, + num_fused_shared_experts=num_fused_shared_experts, + params_dtype=params_dtype, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + inplace=inplace, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, + gemm1_alpha=gemm1_alpha, + gemm1_clamp_limit=gemm1_clamp_limit, + ) + + if quant_config is None: + self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod( + self.use_triton_kernels + ) + else: + self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method( + self, prefix + ) + assert self.quant_method is not None + self.quant_method.create_weights( layer=self, num_experts=self.num_local_experts, hidden_size=hidden_size, - # FIXME: figure out which intermediate_size to use - intermediate_size=self.intermediate_size_per_partition, intermediate_size_per_partition=self.intermediate_size_per_partition, params_dtype=params_dtype, weight_loader=( @@ -229,6 +245,9 @@ class FusedMoE(torch.nn.Module): with_bias=with_bias, ) + self.quant_method.create_moe_runner(self, self.moe_runner_config) + self.dispatcher = StandardDispatcher() + def _load_per_tensor_weight_scale( self, shard_id: str, @@ -811,16 +830,17 @@ class FusedMoE(torch.nn.Module): elif TopKOutputChecker.format_is_triton_kernel(topk_output): raise NotImplementedError() - # Matrix multiply. - with use_symmetric_memory(get_tp_group()) as sm: + dispatch_output = self.dispatcher.dispatch( + hidden_states=hidden_states, topk_output=topk_output + ) - final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - topk_output=topk_output, - moe_runner_config=self.moe_runner_config, - ) - sm.tag(final_hidden_states) + # TODO: consider using symmetric memory + combine_input = self.quant_method.apply( + layer=self, + dispatch_output=dispatch_output, + ) + + final_hidden_states = self.dispatcher.combine(combine_input) final_hidden_states = final_hidden_states[ ..., :origin_hidden_states_dim @@ -955,7 +975,6 @@ class FlashInferFusedMoE(FusedMoE): layer=self, x=hidden_states, topk_output=topk_output, - moe_runner_config=self.moe_runner_config, ) if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): diff --git a/python/sglang/srt/layers/moe/moe_runner/__init__.py b/python/sglang/srt/layers/moe/moe_runner/__init__.py index 9a7fa9c29..3320a7875 100644 --- a/python/sglang/srt/layers/moe/moe_runner/__init__.py +++ b/python/sglang/srt/layers/moe/moe_runner/__init__.py @@ -1,3 +1,4 @@ from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.runner import MoeRunner -__all__ = ["MoeRunnerConfig"] +__all__ = ["MoeRunnerConfig", "MoeRunner"] diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py index 854aeb0e6..c5c14bfea 100644 --- a/python/sglang/srt/layers/moe/moe_runner/base.py +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -1,9 +1,41 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard + +import torch + +from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + CombineInputFormat, + DispatchOutput, + DispatchOutputFormat, +) +from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend + +if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner.triton import ( + TritonRunnerCore, + TritonRunnerInput, + TritonRunnerOutput, + ) @dataclass class MoeRunnerConfig: + + # MoE parameters + num_experts: Optional[int] = None + num_local_experts: Optional[int] = None + hidden_size: Optional[int] = None + intermediate_size_per_partition: Optional[int] = None + layer_id: Optional[int] = None + top_k: Optional[int] = None + num_fused_shared_experts: Optional[int] = None + params_dtype: Optional[torch.dtype] = None + + # Runner configuration activation: str = "silu" apply_router_weight_on_input: bool = False inplace: bool = True @@ -11,3 +43,254 @@ class MoeRunnerConfig: routed_scaling_factor: Optional[float] = None gemm1_alpha: Optional[float] = None gemm1_clamp_limit: Optional[float] = None + + +@dataclass +class RunnerInput(ABC): + + @property + @abstractmethod + def runner_backend(self) -> MoeRunnerBackend: ... + + def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]: + return self.runner_backend == MoeRunnerBackend.TRITON + + +class RunnerOutput(ABC): + + @property + @abstractmethod + def runner_backend(self) -> MoeRunnerBackend: ... + + def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]: + return self.runner_backend == MoeRunnerBackend.TRITON + + +@dataclass +class MoeQuantInfo(ABC): + """Moe quantization data.""" + + pass + + +class MoeRunnerCore(ABC): + + def __init__(self, config: MoeRunnerConfig): + self.config = config + + @abstractmethod + def run( + self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict + ) -> RunnerOutput: + pass + + @property + @abstractmethod + def runner_backend(self) -> MoeRunnerBackend: ... + + def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]: + return self.runner_backend == MoeRunnerBackend.TRITON + + +class FusedOpPool: + + _fused_funcs: dict[str, Callable] = {} + + @classmethod + def register_fused_func( + cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable + ): + key = (a2a_backend_name, runner_backend_name) + if key in cls._fused_funcs: + raise ValueError( + f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered." + ) + assert MoeA2ABackend( + a2a_backend_name + ), f"Invalid dispatch name: {a2a_backend_name}" + assert MoeRunnerBackend( + runner_backend_name + ), f"Invalid runner name: {runner_backend_name}" + cls._fused_funcs[key] = fused_func + + @classmethod + def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]: + key = (dispatch_name, runner_name) + fused_func = cls._fused_funcs.get(key) + return fused_func + + +class PermuteMethodPool: + + _pre_permute_methods: dict[ + Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable + ] = {} + _post_permute_methods: dict[ + Tuple[MoeRunnerBackend, CombineInputFormat], Callable + ] = {} + + @classmethod + def register_pre_permute( + cls, + dispatch_output_name: str, + runner_backend_name: str, + permute_func: Callable, + ): + """ + Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend. + + :param dispatch_output_name: The DispatchOutputFormat name. + :param runner_backend_name: The MoeRunnerBackend name. + :param permute_func: The permute function to register. + """ + key = (dispatch_output_name, runner_backend_name) + if key in cls._pre_permute_methods: + raise ValueError( + f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered." + ) + assert DispatchOutputFormat( + dispatch_output_name + ), f"Invalid dispatch output name: {dispatch_output_name}" + assert MoeRunnerBackend( + runner_backend_name + ), f"Invalid runner backend name: {runner_backend_name}" + cls._pre_permute_methods[key] = permute_func + + @classmethod + def register_post_permute( + cls, + runner_backend_name: str, + combine_input_name: str, + permute_func: Callable, + ): + """ + Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat. + + :param runner_backend_name: The MoeRunnerBackend name. + :param combine_input_name: The CombineInputFormat name. + :param permute_func: The permute function to register. + """ + key = (runner_backend_name, combine_input_name) + if key in cls._post_permute_methods: + raise ValueError( + f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered." + ) + assert MoeRunnerBackend( + runner_backend_name + ), f"Invalid runner backend name: {runner_backend_name}" + assert CombineInputFormat( + combine_input_name + ), f"Invalid combine input name: {combine_input_name}" + cls._post_permute_methods[key] = permute_func + + @classmethod + def get_pre_permute( + cls, + dispatch_output_format: DispatchOutputFormat, + runner_input_format: MoeRunnerBackend, + ) -> Callable: + """ + Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend. + + :param dispatch_output_format: The DispatchOutputFormat type. + :param runner_input_format: The MoeRunnerBackend type. + :return: The registered permute function or None if not found. + """ + key = (dispatch_output_format, runner_input_format) + pre_permute_func = cls._pre_permute_methods.get(key) + assert ( + pre_permute_func is not None + ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered" + return pre_permute_func + + @classmethod + def get_post_permute( + cls, + runner_output_format: MoeRunnerBackend, + combine_input_format: CombineInputFormat, + ) -> Callable: + """ + Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat. + + :param runner_output_format: The MoeRunnerBackend type. + :param combine_input_format: The CombineInputFormat type. + :return: The registered permute function or None if not found. + """ + key = (runner_output_format, combine_input_format) + post_permute_func = cls._post_permute_methods.get(key) + assert ( + post_permute_func is not None + ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered" + return post_permute_func + + +def register_fused_func( + a2a_backend_name: str, + runner_backend_name: str, +) -> Callable: + """ + Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend. + + :param a2a_backend_name: The A2A backend name. + :param runner_backend_name: The MoeRunnerBackend name. + :return: The decorator function. + """ + + def decorator(fused_func: Callable): + FusedOpPool.register_fused_func( + a2a_backend_name, runner_backend_name, fused_func + ) + return fused_func + + return decorator + + +def register_pre_permute( + dispatch_output_name: str, + runner_backend_name: str, +) -> Callable: + """ + Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend. + + :param dispatch_output_name: The DispatchOutputFormat name. + :param runner_backend_name: The MoeRunnerBackend name. + :return: The decorator function. + """ + + def decorator( + permute_func: Callable[ + [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput + ] + ) -> Callable: + + PermuteMethodPool.register_pre_permute( + dispatch_output_name, runner_backend_name, permute_func + ) + return permute_func + + return decorator + + +def register_post_permute( + runner_backend_name: str, + combine_input_name: str, +) -> Callable: + """ + Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat. + + :param runner_backend_name: The MoeRunnerBackend name. + :param combine_input_name: The CombineInputFormat name. + :return: The decorator function. + """ + + def decorator( + permute_func: Callable[ + [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput + ] + ) -> Callable: + PermuteMethodPool.register_post_permute( + runner_backend_name, combine_input_name, permute_func + ) + return permute_func + + return decorator diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py new file mode 100644 index 000000000..995813690 --- /dev/null +++ b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING + +from sglang.srt.layers.moe.moe_runner.base import ( + FusedOpPool, + MoeRunnerConfig, + PermuteMethodPool, +) +from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore +from sglang.srt.layers.moe.token_dispatcher.base import ( + CombineInput, + CombineInputFormat, + DispatchOutput, +) +from sglang.srt.layers.moe.utils import get_moe_a2a_backend + +if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo + from sglang.srt.layers.moe.utils import MoeRunnerBackend + +logger = logging.getLogger(__name__) + + +class MoeRunner: + + def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): + self.runner_backend = runner_backend + self.config = config + + self.fused_func = None + + if runner_backend.is_triton(): + self.runner_core = TritonRunnerCore(config) + else: + raise NotImplementedError(f"Unsupported runner backend: {runner_backend}") + + a2a_backend_name = get_moe_a2a_backend().value + runner_backend_name = runner_backend.value + + self.fused_func = FusedOpPool.get_fused_func( + a2a_backend_name, runner_backend_name + ) + + SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get( + "SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0" + ) + if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1": + logger.info( + "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func" + ) + self.fused_func = None + + def run( + self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo + ) -> CombineInput: + + if self.fused_func is not None: + return self.fused_func(dispatch_output, quant_info, self.config) + + dispatch_format = dispatch_output.format.value + runner_format = self.runner_core.runner_backend.value + self.pre_permute_func = PermuteMethodPool.get_pre_permute( + dispatch_format, runner_format + ) + + running_state = {} + runner_input = self.pre_permute_func( + dispatch_output, quant_info, self.config, running_state + ) + runner_output = self.runner_core.run(runner_input, quant_info, running_state) + + runner_format = self.runner_core.runner_backend.value + combine_format = dispatch_output.format.value + self.post_permute_func = PermuteMethodPool.get_post_permute( + runner_format, combine_format + ) + combine_input = self.post_permute_func( + runner_output, quant_info, self.config, running_state + ) + + return combine_input diff --git a/python/sglang/srt/layers/moe/moe_runner/triton.py b/python/sglang/srt/layers/moe/moe_runner/triton.py new file mode 100644 index 000000000..bc0476812 --- /dev/null +++ b/python/sglang/srt/layers/moe/moe_runner/triton.py @@ -0,0 +1,442 @@ +from __future__ import annotations + +import functools +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional + +import torch +import triton.language as tl + +from sglang.srt.layers.moe.moe_runner.base import ( + MoeQuantInfo, + MoeRunnerConfig, + MoeRunnerCore, + RunnerInput, + RunnerOutput, + register_fused_func, + register_post_permute, + register_pre_permute, +) +from sglang.srt.layers.moe.token_dispatcher import ( + StandardCombineInput, + StandardDispatchOutput, +) +from sglang.srt.layers.moe.utils import MoeRunnerBackend +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip + +_is_hip = is_hip() +_is_cuda = is_cuda() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() +_use_aiter = bool(int(os.getenv("SGLANG_MOE_USE_AITER", "0"))) +_MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 + + +if _is_cuda: + from sgl_kernel import gelu_and_mul, silu_and_mul +elif _is_cpu and _is_cpu_amx_available: + pass +elif _is_hip: + from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul + + if _use_aiter: + try: + from aiter import moe_sum + except ImportError: + raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") + + +if _is_cuda or _is_hip: + from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size + + +@dataclass +class TritonRunnerInput(RunnerInput): + + hidden_states: torch.Tensor + topk_weights: torch.Tensor + topk_ids: torch.Tensor + sorted_token_ids: torch.Tensor + expert_ids: torch.Tensor + num_tokens_post_padded: torch.Tensor + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.TRITON + + +@dataclass +class TritonRunnerOutput(RunnerOutput): + + hidden_states: torch.Tensor + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.TRITON + + +@dataclass +class TritonMoeQuantInfo(MoeQuantInfo): + w13_weight: torch.Tensor + w2_weight: torch.Tensor + b13: Optional[torch.Tensor] = None + b2: Optional[torch.Tensor] = None + use_fp8_w8a8: bool = False + use_int8_w8a8: bool = False + use_int8_w8a16: bool = False + use_int4_w4a16: bool = False + per_channel_quant: bool = False + w13_scale: Optional[torch.Tensor] = None + w2_scale: Optional[torch.Tensor] = None + w13_zp: Optional[torch.Tensor] = None + w2_zp: Optional[torch.Tensor] = None + a13_scale: Optional[torch.Tensor] = None + a2_scale: Optional[torch.Tensor] = None + block_shape: Optional[List[int]] = None + + +class TritonRunnerCore(MoeRunnerCore): + + def __init__(self, config: MoeRunnerConfig): + super().__init__(config) + + def run( + self, + runner_input: TritonRunnerInput, + quant_info: TritonMoeQuantInfo, + running_state: dict, + ) -> TritonRunnerOutput: + + # TODO: move these functions to the triton runner + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + invoke_fused_moe_kernel, + moe_sum_reduce_torch_compile, + moe_sum_reduce_triton, + swiglu_with_alpha_and_limit, + ) + + hidden_states = runner_input.hidden_states + topk_weights = runner_input.topk_weights + topk_ids = runner_input.topk_ids + sorted_token_ids = runner_input.sorted_token_ids + expert_ids = runner_input.expert_ids + num_tokens_post_padded = runner_input.num_tokens_post_padded + + w13 = quant_info.w13_weight + w2 = quant_info.w2_weight + b13 = quant_info.b13 + b2 = quant_info.b2 + a13_scale = quant_info.a13_scale + a2_scale = quant_info.a2_scale + w13_scale = quant_info.w13_scale + w2_scale = quant_info.w2_scale + w13_zp = quant_info.w13_zp + w2_zp = quant_info.w2_zp + block_shape = quant_info.block_shape + per_channel_quant = quant_info.per_channel_quant + use_fp8_w8a8 = quant_info.use_fp8_w8a8 + use_int8_w8a8 = quant_info.use_int8_w8a8 + use_int8_w8a16 = quant_info.use_int8_w8a16 + use_int4_w4a16 = quant_info.use_int4_w4a16 + + activation = self.config.activation + no_combine = self.config.no_combine + inplace = self.config.inplace + gemm1_alpha = self.config.gemm1_alpha + gemm1_limit = self.config.gemm1_clamp_limit + routed_scaling_factor = self.config.routed_scaling_factor + apply_router_weight_on_input = self.config.apply_router_weight_on_input + + M = hidden_states.shape[0] + E, N, _ = w13.shape + compute_type = ( + tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + ) + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + invoke_fused_moe_kernel( + hidden_states, + w13, + b13, + intermediate_cache1, + a13_scale, + w13_scale, + w13_zp, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + topk_ids.shape[1], + running_state["config"], + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if activation == "silu": + if gemm1_alpha is not None: + assert gemm1_limit is not None + intermediate_cache2 = swiglu_with_alpha_and_limit( + intermediate_cache1.view(-1, N), + gemm1_alpha, + gemm1_limit, + ) + elif _is_cuda: + silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + vllm_ops.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) + elif activation == "gelu": + assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" + assert gemm1_limit is None, "gemm1_limit is not supported for gelu" + if _is_cuda: + gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + vllm_ops.gelu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) + else: + raise ValueError(f"Unsupported activation: {activation=}") + + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if no_combine: + assert not inplace + out_hidden_states = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + elif inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + b2, + ( + intermediate_cache3 + if not no_combine and topk_ids.shape[1] != 1 + else out_hidden_states.unsqueeze(0) + ), + a2_scale, + w2_scale, + w2_zp, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + running_state["config"], + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ) + + if routed_scaling_factor is None: + routed_scaling_factor = 1.0 + + if no_combine: + pass + elif _is_cuda: + if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0: + pass # we write directly into out_hidden_states + elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0: + torch.add( + intermediate_cache3[:, 0], + intermediate_cache3[:, 1], + out=out_hidden_states, + ).squeeze(dim=1) + else: + # According to micro benchmark results, torch.compile can get better performance for small token. + if M <= 32: + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + routed_scaling_factor, + ) + else: + moe_sum_reduce_triton( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + routed_scaling_factor, + ) + elif _is_hip: + if _use_aiter: + moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + ) + else: + vllm_ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + ) + else: + vllm_ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + ) + + return TritonRunnerOutput( + hidden_states=out_hidden_states, + ) + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.TRITON + + +@register_fused_func("none", "triton") +def fused_experts_none_to_triton( + dispatch_output: StandardDispatchOutput, + quant_info: TritonMoeQuantInfo, + runner_config: MoeRunnerConfig, +) -> StandardCombineInput: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + + output = fused_experts( + hidden_states=dispatch_output.hidden_states, + w1=quant_info.w13_weight, + w2=quant_info.w2_weight, + topk_output=dispatch_output.topk_output, + moe_runner_config=runner_config, + b1=quant_info.b13, + b2=quant_info.b2, + use_fp8_w8a8=quant_info.use_fp8_w8a8, + use_int8_w8a8=quant_info.use_int8_w8a8, + use_int8_w8a16=quant_info.use_int8_w8a16, + use_int4_w4a16=quant_info.use_int4_w4a16, + per_channel_quant=quant_info.per_channel_quant, + w1_scale=quant_info.w13_scale, + w2_scale=quant_info.w2_scale, + w1_zp=quant_info.w13_zp, + w2_zp=quant_info.w2_zp, + a1_scale=quant_info.a13_scale, + a2_scale=quant_info.a2_scale, + block_shape=quant_info.block_shape, + ) + + return StandardCombineInput( + hidden_states=output, + ) + + +@register_pre_permute("standard", "triton") +def pre_permute_standard_to_triton( + dispatch_output: StandardDispatchOutput, + quant_info: TritonMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> TritonRunnerInput: + + # NOTE: this is dead code as a fused func for standard format is registered. + # This is left here for testing and examples. + + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + get_config_dtype_str, + moe_align_block_size, + try_get_optimal_moe_config, + ) + from sglang.srt.layers.moe.topk import TopKOutputChecker + + hidden_states, topk_output = dispatch_output + + assert TopKOutputChecker.format_is_standard(topk_output) + + num_tokens = hidden_states.shape[0] + num_local_experts = runner_config.num_local_experts + + if ( + not (quant_info.use_fp8_w8a8 or quant_info.use_int8_w8a8) + or quant_info.block_shape is not None + or _use_aiter + ): + padding_size = 0 + else: + padding_size = _MOE_PADDING_SIZE + + config_dtype = get_config_dtype_str( + use_fp8_w8a8=quant_info.use_fp8_w8a8, + use_int8_w8a8=quant_info.use_int8_w8a8, + use_int8_w8a16=quant_info.use_int8_w8a16, + use_int4_w4a16=quant_info.use_int4_w4a16, + dtype=hidden_states.dtype, + ) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + quant_info.w13_weight.shape, + ( + num_local_experts, + quant_info.w2_weight.shape[1], + quant_info.w2_weight.shape[2] - padding_size, + ), + topk_output.topk_ids.shape[1], + config_dtype, + block_shape=quant_info.block_shape, + ) + + config = get_config_func(num_tokens) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_output.topk_ids, config["BLOCK_SIZE_M"], num_local_experts + ) + + running_state["config"] = config + + return TritonRunnerInput( + hidden_states=hidden_states, + topk_weights=topk_output.topk_weights, + topk_ids=topk_output.topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + ) + + +@register_post_permute("triton", "standard") +def post_permute_triton_to_standard( + runner_output: TritonRunnerOutput, + quant_info: TritonMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> StandardCombineInput: + + # NOTE: this is dead code as a fused func for standard format is registered. + # This is left here for testing and examples. + + return StandardCombineInput( + hidden_states=runner_output.hidden_states, + ) diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py index 7802968ac..82f3ca5cb 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py @@ -1,6 +1,9 @@ -from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import ( +from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, BaseDispatcherConfig, + CombineInput, + CombineInputChecker, + CombineInputFormat, DispatchOutput, DispatchOutputChecker, DispatchOutputFormat, @@ -9,21 +12,32 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import ( AscendDeepEPLLOutput, DeepEPConfig, DeepEPDispatcher, + DeepEPLLCombineInput, DeepEPLLOutput, + DeepEPNormalCombineInput, DeepEPNormalOutput, ) -from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput +from sglang.srt.layers.moe.token_dispatcher.standard import ( + StandardCombineInput, + StandardDispatchOutput, +) __all__ = [ "AscendDeepEPLLOutput", "BaseDispatcher", "BaseDispatcherConfig", + "CombineInput", + "CombineInputChecker", + "CombineInputFormat", "DispatchOutput", "DispatchOutputFormat", "DispatchOutputChecker", "StandardDispatchOutput", + "StandardCombineInput", "DeepEPConfig", "DeepEPDispatcher", "DeepEPNormalOutput", "DeepEPLLOutput", + "DeepEPLLCombineInput", + "DeepEPNormalCombineInput", ] diff --git a/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py b/python/sglang/srt/layers/moe/token_dispatcher/base.py similarity index 52% rename from python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py rename to python/sglang/srt/layers/moe/token_dispatcher/base.py index d5ff8cf77..b0ca798ca 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from enum import Enum, auto +from enum import Enum from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable import torch @@ -9,10 +9,16 @@ import torch if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import ( AscendDeepEPLLOutput, + DeepEPLLCombineInput, DeepEPLLOutput, + DeepEPNormalCombineInput, DeepEPNormalOutput, + StandardCombineInput, StandardDispatchOutput, ) + from sglang.srt.layers.moe.topk import TopKOutput + +# ------------------------------ Dispatch Output ------------------------------------- class DispatchOutputChecker: @@ -50,10 +56,10 @@ class DispatchOutputChecker: class DispatchOutputFormat(Enum): - STANDARD = auto() - DEEPEP_NORMAL = auto() - DEEPEP_LL = auto() - ASCENT_LL = auto() + STANDARD = "standard" + DEEPEP_NORMAL = "deepep_normal" + DEEPEP_LL = "deepep_ll" + ASCENT_LL = "ascent_ll" def is_standard(self) -> bool: return self == DispatchOutputFormat.STANDARD @@ -78,10 +84,63 @@ class DispatchOutputFormat(Enum): class DispatchOutput(Protocol): """Protocol for dispatch outputs in different formats.""" + # TODO: add hidden_states to the protocol + @property def format(self) -> DispatchOutputFormat: ... +# ------------------------------ Combine Input ------------------------------------- + + +class CombineInputChecker: + @staticmethod + def format_is_standard( + combine_input: CombineInput, + ) -> TypeGuard[StandardCombineInput]: + return combine_input.format == CombineInputFormat.STANDARD + + @staticmethod + def format_is_deepep_normal( + combine_input: CombineInput, + ) -> TypeGuard[DeepEPNormalCombineInput]: + return combine_input.format == CombineInputFormat.DEEPEP_NORMAL + + @staticmethod + def format_is_deepep_ll( + combine_input: CombineInput, + ) -> TypeGuard[DeepEPLLCombineInput]: + return combine_input.format == CombineInputFormat.DEEPEP_LL + + @staticmethod + def format_is_deepep( + combine_input: CombineInput, + ) -> TypeGuard[Union[DeepEPNormalCombineInput, DeepEPLLCombineInput]]: + return combine_input.format in [ + CombineInputFormat.DEEPEP_NORMAL, + CombineInputFormat.DEEPEP_LL, + ] + + +class CombineInputFormat(Enum): + STANDARD = "standard" + DEEPEP_NORMAL = "deepep_normal" + DEEPEP_LL = "deepep_ll" + + +@runtime_checkable +class CombineInput(Protocol): + """Protocol for combine inputs in different formats.""" + + # TODO: add hidden_states to the protocol + + @property + def format(self) -> CombineInputFormat: ... + + +# ------------------------------ Base Dispatcher ------------------------------------- + + class BaseDispatcherConfig(ABC): """Base class for dispatcher configs.""" @@ -92,9 +151,11 @@ class BaseDispatcher(ABC): """Base class for dispatchers.""" @abstractmethod - def dispatch(self, *args, **kwargs) -> DispatchOutput: + def dispatch( + self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs + ) -> DispatchOutput: pass @abstractmethod - def combine(self, *args, **kwargs) -> torch.Tensor: + def combine(self, combine_input: CombineInput, **kwargs) -> torch.Tensor: pass diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index c6ea49089..ccb13e50c 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -5,13 +5,15 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder -from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled -from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import ( +from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, BaseDispatcherConfig, + CombineInput, + CombineInputFormat, DispatchOutput, DispatchOutputFormat, ) +from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.utils import ( get_bool_env_var, @@ -56,6 +58,7 @@ class DeepEPNormalOutput(NamedTuple): """DeepEP normal dispatch output.""" hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] + # hidden_states_scale topk_idx: torch.Tensor topk_weights: torch.Tensor num_recv_tokens_per_expert: List[int] @@ -99,6 +102,30 @@ assert isinstance(DeepEPLLOutput, DispatchOutput) assert isinstance(AscendDeepEPLLOutput, DispatchOutput) +class DeepEPNormalCombineInput(NamedTuple): + """DeepEP normal combine input.""" + + pass + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.DEEPEP_NORMAL + + +class DeepEPLLCombineInput(NamedTuple): + """DeepEP low latency combine input.""" + + pass + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.DEEPEP_LL + + +assert isinstance(DeepEPNormalCombineInput, CombineInput) +assert isinstance(DeepEPLLCombineInput, CombineInput) + + class DeepEPDispatchMode(IntEnum): NORMAL = auto() LOW_LATENCY = auto() diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index 3e09e0bf6..f984104f6 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -1,19 +1,61 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple -from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import ( +import torch + +from sglang.srt.layers.moe.token_dispatcher.base import ( + BaseDispatcher, + CombineInput, + CombineInputFormat, DispatchOutput, DispatchOutputFormat, ) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + class StandardDispatchOutput(NamedTuple): """Standard dispatch output.""" + hidden_states: torch.Tensor + topk_output: TopKOutput + @property def format(self) -> DispatchOutputFormat: return DispatchOutputFormat.STANDARD assert isinstance(StandardDispatchOutput, DispatchOutput) + + +class StandardCombineInput(NamedTuple): + """Standard combine input.""" + + hidden_states: torch.Tensor + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.STANDARD + + +assert isinstance(StandardCombineInput, CombineInput) + + +class StandardDispatcher(BaseDispatcher): + + def dispatch( + self, hidden_states: torch.Tensor, topk_output: TopKOutput + ) -> DispatchOutput: + return StandardDispatchOutput( + hidden_states=hidden_states, topk_output=topk_output + ) + + def combine(self, combine_input: CombineInput) -> torch.Tensor: + if isinstance(combine_input, StandardCombineInput): + return combine_input.hidden_states + else: + # TODO: this branch should be removed in the future + assert isinstance(combine_input, torch.Tensor) + return combine_input diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 1be17ea68..b4e4ec424 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import importlib.util +import logging from enum import Enum from functools import lru_cache from typing import TYPE_CHECKING, Optional @@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import ( get_attention_dp_size, is_dp_attention_enabled, ) -from sglang.srt.utils import logger if TYPE_CHECKING: from sglang.srt.server_args import ServerArgs +logger = logging.getLogger(__name__) + class MoeA2ABackend(Enum): @@ -131,7 +133,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend: global MOE_A2A_BACKEND if MOE_A2A_BACKEND is None: logger.warning("MOE_A2A_BACKEND is not initialized, using default backend") - MOE_A2A_BACKEND = MoeA2ABackend(None) + MOE_A2A_BACKEND = MoeA2ABackend.NONE return MOE_A2A_BACKEND @@ -139,7 +141,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend: global MOE_RUNNER_BACKEND if MOE_RUNNER_BACKEND is None: logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend") - MOE_RUNNER_BACKEND = MoeRunnerBackend("triton") + MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO return MOE_RUNNER_BACKEND @@ -147,7 +149,7 @@ def get_deepep_mode() -> DeepEPMode: global DEEPEP_MODE if DEEPEP_MODE is None: logger.warning("DEEPEP_MODE is not initialized, using auto mode") - DEEPEP_MODE = DeepEPMode("auto") + DEEPEP_MODE = DeepEPMode.AUTO return DEEPEP_MODE diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 19deb7dd1..9cba60c2b 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param if TYPE_CHECKING: from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import StandardTopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + StandardDispatchOutput, + CombineInput, + ) from sglang.srt.utils import is_cuda, is_hip @@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase): ) replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: StandardTopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + assert ( - moe_runner_config.activation == "silu" + self.moe_runner_config.activation == "silu" ), "Only SiLU activation is supported." # The input must currently be float16 + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + orig_dtype = x.dtype x = x.half() topk_weights, topk_ids, router_logits = topk_output - return fused_marlin_moe( + output = fused_marlin_moe( x, layer.w13_qweight, layer.w2_qweight, @@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase): w2_zeros=layer.w2_qzeros, num_bits=self.quant_config.weight_bits, ).to(orig_dtype) + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index ec2b4edb1..4a5b7905e 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -3,6 +3,7 @@ from __future__ import annotations import inspect from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type import torch @@ -10,7 +11,7 @@ from torch import nn if TYPE_CHECKING: from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput class QuantizeMethodBase(ABC): @@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase): layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): raise NotImplementedError + @abstractmethod + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + raise NotImplementedError + @abstractmethod def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: DispatchOutput, + ) -> CombineInput: raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index a5966c4d5..60d4e3929 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -9,6 +9,8 @@ import torch from torch.nn import Module from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): layer: Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): @@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ) # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. # Required by column parallel or enabling merged weights - if intermediate_size % block_n != 0: + if intermediate_size_per_partition % block_n != 0: raise ValueError( f"The output_size of gate's and up's weight = " - f"{intermediate_size} is not divisible by " + f"{intermediate_size_per_partition} is not divisible by " f"weight quantization block_n = {block_n}." ) if tp_size > 1: # Required by row parallel - if intermediate_size % block_k != 0: + if intermediate_size_per_partition % block_k != 0: raise ValueError( f"The input_size of down's weight = " - f"{intermediate_size} is not divisible by " + f"{intermediate_size_per_partition} is not divisible by " f"weight quantization block_k = {block_k}." ) # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, ), requires_grad=False, ) @@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): w2_weight = torch.nn.Parameter( torch.empty( - num_experts, hidden_size, intermediate_size, dtype=params_dtype + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, ), requires_grad=False, ) @@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, - 2 * ((intermediate_size + block_n - 1) // block_n), + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), (hidden_size + block_k - 1) // block_k, dtype=torch.float32, ), @@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): torch.ones( num_experts, (hidden_size + block_n - 1) // block_n, - (intermediate_size + block_k - 1) // block_k, + (intermediate_size_per_partition + block_k - 1) // block_k, dtype=torch.float32, ), requires_grad=False, @@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): # Block quant doesn't need to process weights after loading return + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: - # Expert fusion with INT8 quantization - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_output=topk_output, - moe_runner_config=moe_runner_config, + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, use_int8_w8a8=True, - w1_scale=(layer.w13_weight_scale_inv), - w2_scale=(layer.w2_weight_scale_inv), - a1_scale=layer.w13_input_scale, + w13_scale=layer.w13_weight_scale_inv, + w2_scale=layer.w2_weight_scale_inv, + a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.quant_config.weight_block_size, ) + + return self.runner.run(dispatch_output, quant_info) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 320a7ba87..e2ff25e68 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -11,6 +11,8 @@ import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz @@ -30,8 +32,10 @@ from sglang.srt.utils import ( if TYPE_CHECKING: from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( CompressedTensorsConfig, ) @@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ) torch.cuda.empty_cache() + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton import fused_experts + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + moe_runner_config = self.moe_runner_config if ( _use_aiter @@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): and moe_runner_config.apply_router_weight_on_input ): topk_weights, topk_ids, _ = topk_output - return rocm_fused_experts_tkw1( + output = rocm_fused_experts_tkw1( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) + return StandardCombineInput(hidden_states=output) else: - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_output=topk_output, - moe_runner_config=moe_runner_config, + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, use_fp8_w8a8=True, per_channel_quant=self.weight_quant.strategy == QuantizationStrategy.CHANNEL, - w1_scale=layer.w13_weight_scale, + w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, + a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) + return self.runner.run(dispatch_output, quant_info) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): @@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): params_dtype == torch.float16 ), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501 - intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") - # Will transpose the loaded weight along the # intermediate and hidden dim sizes. Will # shard for TP along the transposed dims @@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): # In the case where we have actorder/g_idx, # we do not partition the w2 scales load_full_w2 = self.actorder and self.group_size != -1 - w2_scales_size = ( - intermediate_size_full if load_full_w2 else intermediate_size_per_partition - ) - self.is_k_full = (not self.actorder) or ( - intermediate_size_per_partition == intermediate_size_full - ) + if load_full_w2: + w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size + else: + w2_scales_size = intermediate_size_per_partition + + self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1 if self.strategy == "channel": num_groups_w2 = num_groups_w13 = 1 @@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ) replace_tensor("w2_weight_scale", marlin_w2_scales) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput assert ( - moe_runner_config.activation == "silu" + self.moe_runner_config.activation == "silu" ), "Only SiLU activation is supported." + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + topk_weights, topk_ids, router_logits = topk_output - return torch.ops.vllm.fused_marlin_moe( + output = torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight_packed, layer.w2_weight_packed, @@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): num_bits=self.num_bits, is_k_full=self.is_k_full, ) + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 4915d4d08..89938f4c3 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -30,6 +30,9 @@ except ImportError: from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo +from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker from sglang.srt.layers.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, @@ -81,7 +84,11 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + DispatchOutput, + StandardDispatchOutput, + ) from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config @@ -527,7 +534,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer: Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): @@ -543,18 +550,18 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. # Required by column parallel or enabling merged weights - if intermediate_size % block_n != 0: + if intermediate_size_per_partition % block_n != 0: raise ValueError( f"The output_size of gate's and up's weight = " - f"{intermediate_size} is not divisible by " + f"{intermediate_size_per_partition} is not divisible by " f"weight quantization block_n = {block_n}." ) if tp_size > 1: # Required by row parallel - if intermediate_size % block_k != 0: + if intermediate_size_per_partition % block_k != 0: raise ValueError( f"The input_size of down's weight = " - f"{intermediate_size} is not divisible by " + f"{intermediate_size_per_partition} is not divisible by " f"weight quantization block_k = {block_k}." ) @@ -564,7 +571,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): w13_weight = torch.nn.Parameter( torch.empty( num_experts, - 2 * intermediate_size, + 2 * intermediate_size_per_partition, hidden_size // 8, dtype=params_dtype, ), @@ -572,20 +579,29 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) w2_weight = torch.nn.Parameter( torch.empty( - num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype + num_experts, + hidden_size, + intermediate_size_per_partition // 8, + dtype=params_dtype, ), requires_grad=False, ) else: w13_weight = torch.nn.Parameter( torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, ), requires_grad=False, ) w2_weight = torch.nn.Parameter( torch.empty( - num_experts, hidden_size, intermediate_size, dtype=params_dtype + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, ), requires_grad=False, ) @@ -601,7 +617,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, - 2 * ((intermediate_size + block_n - 1) // block_n), + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), (hidden_size + block_k - 1) // block_k, dtype=torch.float32, ), @@ -611,7 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): torch.ones( num_experts, (hidden_size + block_n - 1) // block_n, - (intermediate_size + block_k - 1) // block_k, + (intermediate_size_per_partition + block_k - 1) // block_k, dtype=torch.float32, ), requires_grad=False, @@ -632,19 +648,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) self.c_strides1 = torch.full( (num_experts,), - 2 * intermediate_size, + 2 * intermediate_size_per_partition, device=w13_weight.device, dtype=torch.int64, ) self.ab_strides2 = torch.full( (num_experts,), - intermediate_size, + intermediate_size_per_partition, device=w2_weight.device, dtype=torch.int64, ) self.c_strides2 = torch.full( (num_experts,), - hidden_size, + intermediate_size_per_partition, device=w2_weight.device, dtype=torch.int64, ) @@ -691,7 +707,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): if _is_hip: # _use_aiter: TODO: add check back after triton kernel # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling w13_weight_scale1 = torch.nn.Parameter( - torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), requires_grad=False, ) w2_weight_scale1 = torch.nn.Parameter( @@ -984,14 +1004,23 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) torch.cuda.empty_cache() + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + dispatch_output: DispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + moe_runner_config = self.moe_runner_config if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu @@ -1001,7 +1030,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): moe_runner_config.apply_router_weight_on_input, topk_weights, x ) - return torch.ops.sgl_kernel.fused_experts_cpu( + output = torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, layer.w2_weight, @@ -1017,6 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): None, # a2_scale True, # is_vnni ) + return StandardCombineInput(hidden_states=output) if _is_hip: ret = self.maybe_apply_hip_fused_experts( @@ -1027,7 +1057,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): moe_runner_config.no_combine, ) if ret is not None: - return ret + return StandardCombineInput(hidden_states=ret) if self.use_cutlass_fused_experts_fp8: from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 @@ -1056,17 +1086,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.problem_sizes2, use_fp8_blockscale=True, ) - # Scale by routed_scaling_factor is fused into select_experts. - return output - # Expert fusion with FP8 quantization - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_output=topk_output, - moe_runner_config=moe_runner_config, + return StandardCombineInput(hidden_states=output) + + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, use_fp8_w8a8=True, - w1_scale=( + w13_scale=( layer.w13_weight_scale_inv if self.block_quant else layer.w13_weight_scale @@ -1074,20 +1100,22 @@ class Fp8MoEMethod(FusedMoEMethodBase): w2_scale=( layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale ), - a1_scale=layer.w13_input_scale, + a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.quant_config.weight_block_size, ) + return self.runner.run(dispatch_output, quant_info) def apply_with_router_logits( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, + dispatch_output: StandardDispatchOutput, ) -> torch.Tensor: - activation = moe_runner_config.activation - routed_scaling_factor = moe_runner_config.routed_scaling_factor + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + activation = self.moe_runner_config.activation + routed_scaling_factor = self.moe_runner_config.routed_scaling_factor from flashinfer.fused_moe import trtllm_fp8_block_scale_moe diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index c770708b0..ccd3d46f7 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import ( if TYPE_CHECKING: from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + StandardDispatchOutput, + CombineInput, + ) from sglang.srt.utils import is_cuda @@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): from sglang.srt.layers.linear import set_weight_attrs from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - intermediate_size = extra_weight_attrs.pop("intermediate_size") - - self.is_k_full = (not self.quant_config.desc_act) or ( - intermediate_size_per_partition == intermediate_size - ) + self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1 if self.quant_config.group_size != -1: scales_size13 = hidden_size // self.quant_config.group_size - w2_scales_size = ( - intermediate_size - if self.quant_config.desc_act - else intermediate_size_per_partition - ) + if self.quant_config.desc_act: + w2_scales_size = intermediate_size_per_partition + else: + w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size scales_size2 = w2_scales_size // self.quant_config.group_size strategy = FusedMoeWeightScaleSupported.GROUP.value else: @@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) replace_parameter(layer, "w2_scales", marlin_w2_scales) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + # Delay the import to avoid circular dependency assert ( - moe_runner_config.activation == "silu" + self.moe_runner_config.activation == "silu" ), "Only SiLU activation is supported." # The input must currently be float16 @@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, router_logits = topk_output - return fused_marlin_moe( + output = fused_marlin_moe( x, layer.w13_qweight, layer.w2_qweight, @@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): num_bits=self.quant_config.weight_bits, is_k_full=self.is_k_full, ).to(orig_dtype) + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index bd4367234..eb9bc2f97 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter from sglang.srt.distributed import get_tp_group from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer from sglang.srt.layers.moe import ( + MoeRunner, + MoeRunnerBackend, + MoeRunnerConfig, should_use_flashinfer_cutlass_moe_fp4_allgather, should_use_flashinfer_trtllm_moe, ) from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2 if TYPE_CHECKING: from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) if is_cuda(): from sgl_kernel import scaled_fp4_quant @@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): @@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): w13_weight = ModelWeightParameter( data=torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=weight_dtype, ), input_dim=2, output_dim=1, @@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): w2_weight = ModelWeightParameter( data=torch.empty( - num_experts, hidden_size, intermediate_size, dtype=weight_dtype + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=weight_dtype, ), input_dim=2, output_dim=1, @@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): max_w13_scales = layer.w13_weight_scale.max(dim=1).values # Requantize each expert's weights using the combined scale - # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size) - # where the first intermediate_size rows are w1, the next are w3 - intermediate_size = layer.w13_weight.shape[1] // 2 + # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size) + # where the first intermediate_size_per_partition rows are w1, the next are w3 + intermediate_size_per_partition = layer.w13_weight.shape[1] // 2 for expert_id in range(layer.w13_weight.shape[0]): start = 0 for shard_id in range(2): # w1 and w3 # Dequantize using the original scale for this shard dq_weight = per_tensor_dequantize( layer.w13_weight[expert_id][ - start : start + intermediate_size, : + start : start + intermediate_size_per_partition, : ], layer.w13_weight_scale[expert_id][shard_id], ) # Requantize using the combined max scale ( layer.w13_weight[expert_id][ - start : start + intermediate_size, : + start : start + intermediate_size_per_partition, : ], _, ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - start += intermediate_size + start += intermediate_size_per_partition # Update the scale parameter to be per-expert instead of per-shard layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False) @@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer.w2_input_scale.max(), requires_grad=False ) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_output=topk_output, - moe_runner_config=moe_runner_config, + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, use_fp8_w8a8=True, - per_channel_quant=False, # ModelOpt uses per-tensor quantization - w1_scale=layer.w13_weight_scale, + per_channel_quant=False, + w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, + a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) + return self.runner.run(dispatch_output, quant_info) + class ModelOptFp4Config(QuantizationConfig): """Config class for FP4.""" @@ -1278,21 +1292,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 return self.enable_flashinfer_cutlass_moe + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + def apply( self, layer: FusedMoE, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + assert ( - moe_runner_config.activation == "silu" + self.moe_runner_config.activation == "silu" ), "Only SiLU activation is supported." + moe_runner_config = self.moe_runner_config + # Check if this is a FlashInferFP4MoE layer that should handle its own forward if hasattr(layer, "gemm1_weights_fp4_shuffled"): # This layer was processed with flashinfer TRTLLM - delegate to its own forward - return layer.forward(x, topk_output) + return StandardCombineInput(hidden_states=layer.forward(x, topk_output)) if self.enable_flashinfer_cutlass_moe: assert ( @@ -1345,13 +1370,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): tp_rank=layer.moe_tp_rank, tune_max_num_tokens=next_power_of_2(x.shape[0]), )[0] - # Scale by routed_scaling_factor is fused into select_experts. if should_use_flashinfer_cutlass_moe_fp4_allgather(): output, global_output = get_local_dp_buffer(), output get_tp_group().reduce_scatterv( global_output, output=output, sizes=get_dp_global_num_tokens() ) - return output + return StandardCombineInput(hidden_states=output) from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 @@ -1372,4 +1396,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input, ).to(x.dtype) # Scale by routed_scaling_factor is fused into select_experts. - return output + + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index 7f2c78cbb..531e4271f 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -9,6 +9,8 @@ import torch from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import get_tp_group +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.awq import AWQConfig from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs logger = logging.getLogger(__name__) if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) def get_weight_perm(num_bits: int): @@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase): layer.register_parameter(key, param) set_weight_attrs(param, extra_weight_attrs) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - # avoid circular import - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: assert ( - moe_runner_config.activation == "silu" + self.moe_runner_config.activation == "silu" ), "Only SiLU activation is supported." weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp - return fused_experts( - x, - layer.w13_qweight, - layer.w2_qweight, - topk_output=topk_output, - moe_runner_config=moe_runner_config, + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_qweight, + w2_weight=layer.w2_qweight, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, - w1_scale=layer.w13_scales, + w13_scale=layer.w13_scales, w2_scale=layer.w2_scales, - w1_zp=layer.w13_qzeros if has_zp else None, + w13_zp=layer.w13_qzeros if has_zp else None, w2_zp=layer.w2_qzeros if has_zp else None, block_shape=[0, layer.group_size], ) + return self.runner.run(dispatch_output, quant_info) @staticmethod def get_weight_loader(layer, weight_loader): diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 8180fb8b9..0d98d00d6 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional import torch from torch.nn.parameter import Parameter +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.moe.utils import get_moe_runner_backend from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -59,8 +61,10 @@ if is_flashinfer_available(): logger = logging.getLogger(__name__) if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) _is_hip = is_hip() @@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, with_bias: bool = False, **extra_weight_attrs, @@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling - intermediate_size_per_partition_after_pad = intermediate_size + intermediate_size_per_partition_after_pad = intermediate_size_per_partition if _is_sm100_supported: if self.use_flashinfer: intermediate_size_per_partition_after_pad = round_up( - intermediate_size, 256 + intermediate_size_per_partition, 256 ) hidden_size = round_up(hidden_size, 256) else: intermediate_size_per_partition_after_pad = round_up( - intermediate_size, 64 + intermediate_size_per_partition, 64 ) elif has_triton_kernels: # TODO: this is a hack to make # intermediate_size_per_partition_after_pad the same as the # per_rank_intermediate_size during weight loading intermediate_size_per_partition_after_pad = round_up( - intermediate_size, mxfp4_block + intermediate_size_per_partition, mxfp4_block ) - self.intermediate_size = intermediate_size_per_partition_after_pad + self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad self.hidden_size = hidden_size # Fused gate_up_proj (column parallel) @@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): assert ( layer.w13_weight.dim() == 3 and layer.w13_weight.shape[0] == self.num_experts - and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[1] + == self.intermediate_size_per_partition * 2 and layer.w13_weight.shape[2] == self.hidden_size // 2 ) assert ( layer.w13_weight_scale.dim() == 3 and layer.w13_weight_scale.shape[0] == self.num_experts - and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[1] + == self.intermediate_size_per_partition * 2 and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size ) assert ( layer.w2_weight.dim() == 3 and layer.w2_weight.shape[0] == self.num_experts and layer.w2_weight.shape[1] == self.hidden_size - and layer.w2_weight.shape[2] == self.intermediate_size // 2 + and layer.w2_weight.shape[2] + == self.intermediate_size_per_partition // 2 ) assert ( layer.w2_weight_scale.dim() == 3 and layer.w2_weight_scale.shape[1] == self.hidden_size and layer.w2_weight_scale.shape[2] - == self.intermediate_size // sf_block_size + == self.intermediate_size_per_partition // sf_block_size ) assert ( layer.w13_weight_bias.dim() == 2 and layer.w13_weight_bias.shape[0] == self.num_experts - and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2 + and layer.w13_weight_bias.shape[1] + == self.intermediate_size_per_partition * 2 ) assert ( layer.w2_weight_bias.dim() == 2 @@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): torch.stack(gemm1_scales_mxfp4_shuffled) .reshape( self.num_experts, - 2 * self.intermediate_size, + 2 * self.intermediate_size_per_partition, self.hidden_size // sf_block_size, ) .view(torch.float8_e4m3fn) @@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): .reshape( self.num_experts, self.hidden_size, - self.intermediate_size // sf_block_size, + self.intermediate_size_per_partition // sf_block_size, ) .view(torch.float8_e4m3fn) ) @@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): return tile_tokens_dim + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput from sglang.srt.layers.moe.topk import TopKOutputChecker + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + moe_runner_config = self.moe_runner_config + if self.use_flashinfer: # When bf16 mode is enabled, we don't need to quantize the input, # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations, @@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): top_k, None, # n_group # TODO: support n_group None, # topk_group # TODO: support topk_group - self.intermediate_size, # padded to multiple of 256 + self.intermediate_size_per_partition, # padded to multiple of 256 layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset layer.num_local_experts, # local num experts None, @@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): 1, # routing_method_type, renormalize True, # do finalize )[0] - return trtllm_gen_output + return StandardCombineInput(hidden_states=trtllm_gen_output) if self.use_triton_kernels: assert ( layer.moe_ep_size == 1 ), "Expert parallel is not supported when using triton kernels" if self.with_bias: - return self.triton_kernel_moe_with_bias_forward( + output = self.triton_kernel_moe_with_bias_forward( hidden_states=x, w1=self.w13_weight_triton_tensor, w1_pcg=self.w13_precision_config, @@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): moe_runner_config=moe_runner_config, ) else: - return self.triton_kernel_moe_forward( + output = self.triton_kernel_moe_forward( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_output=topk_output, moe_runner_config=moe_runner_config, ) + return StandardCombineInput(hidden_states=output) else: - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_output=topk_output, - moe_runner_config=moe_runner_config, - b1=layer.w13_weight_bias, - b2=layer.w2_weight_bias, + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + w13_weight_bias=layer.w13_weight_bias, + w2_weight_bias=layer.w2_weight_bias, ) + return self.runner.run(dispatch_output, quant_info) class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase): @@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase): return w, mx_scales - def process_weights_after_loading(self, layer: Module) -> None: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data) w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data) @@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase): layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False) layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + topk_weights, topk_ids, _ = topk_output if _is_hip: topk_weights = topk_weights.to( torch.float32 ) # aiter's moe_sorting requires topk_weights to be FP32 - return fused_moe( + output = fused_moe( x, layer.w13_weight, layer.w2_weight, @@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase): w2_scale=layer.w2_weight_scale, activation=( ActivationType.Silu - if moe_runner_config.activation == "silu" + if self.moe_runner_config.activation == "silu" else ActivationType.Gelu ), doweight_stage1=False, ) + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py index 194fa414d..f6e750a2c 100644 --- a/python/sglang/srt/layers/quantization/quark/quark_moe.py +++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py @@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk from aiter.fused_moe import fused_moe from aiter.utility.fp4_utils import e8m0_shuffle +from sglang.srt.layers.moe import MoeRunnerConfig +from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) + from sglang.srt.layers.quantization.quark.quark import QuarkConfig + logger = logging.getLogger(__name__) __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"] @@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"] OCP_MX_BLOCK_SIZE = 32 if TYPE_CHECKING: - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.quantization import QuarkConfig -class QuarkMoEMethod: - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase +class QuarkMoEMethod(FusedMoEMethodBase): - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) + def __init__(self, quant_config: QuarkConfig): + self.quant_config = quant_config @staticmethod def get_moe_method( - quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 + quant_config: QuarkConfig, # type: ignore # noqa E501 # noqa F821 module: torch.nn.Module, layer_name: str, ) -> "QuarkMoEMethod": @@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False) layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + moe_runner_config = self.moe_runner_config topk_weights, topk_ids, _ = topk_output - return fused_moe( + output = fused_moe( x, layer.w13_weight, layer.w2_weight, @@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): ), doweight_stage1=False, ) + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 101bfe4f1..7a393748b 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter from sglang.srt.custom_op import CustomOp from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, LinearMethodBase, @@ -24,8 +26,10 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None @@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, with_bias: bool = False, **extra_weight_attrs, @@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self.with_bias = with_bias # Fused gate_up_proj (column parallel) - w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size + w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size if self.use_triton_kernels: w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n w13_weight = torch.nn.Parameter( @@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): if self.with_bias: w13_weight_bias = torch.nn.Parameter( - torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32), + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), requires_grad=False, ) layer.register_parameter("w13_weight_bias", w13_weight_bias) @@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): # down_proj (row parallel) w2_weight_n, w2_weight_k = ( hidden_size, - intermediate_size, + intermediate_size_per_partition, ) if self.use_triton_kernels: w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n @@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): return + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: return self.forward( - x=x, layer=layer, - topk_output=topk_output, - moe_runner_config=moe_runner_config, + dispatch_output=dispatch_output, ) def forward_cuda( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + moe_runner_config = self.moe_runner_config if self.use_triton_kernels: if self.with_bias: assert self.triton_kernel_moe_with_bias_forward is not None - return self.triton_kernel_moe_with_bias_forward( + output = self.triton_kernel_moe_with_bias_forward( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) else: assert self.triton_kernel_moe_forward is not None - return self.triton_kernel_moe_forward( + output = self.triton_kernel_moe_forward( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_output=topk_output, moe_runner_config=moe_runner_config, ) + return StandardCombineInput(hidden_states=output) else: if _use_aiter: assert not moe_runner_config.no_combine, "unsupported" @@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_weights = torch.ones_like( topk_weights, dtype=torch.float32 ) # topk_weights must be FP32 (float32) - return fused_moe( + output = fused_moe( x, layer.w13_weight, layer.w2_weight, @@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): else ActivationType.Gelu ), ) + return StandardCombineInput(hidden_states=output) else: - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( - fused_experts, - ) - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - b1=getattr(layer, "w13_weight_bias", None), + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + b13=getattr(layer, "w13_weight_bias", None), b2=getattr(layer, "w2_weight_bias", None), - topk_output=topk_output, - moe_runner_config=moe_runner_config, ) + return self.runner.run(dispatch_output, quant_info) def forward_cpu( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + moe_runner_config = self.moe_runner_config + assert ( moe_runner_config.activation == "silu" ), f"activation = {moe_runner_config.activation} is not supported." @@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): x, topk_weights = apply_topk_weights_cpu( moe_runner_config.apply_router_weight_on_input, topk_weights, x ) - return torch.ops.sgl_kernel.fused_experts_cpu( + output = torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, layer.w2_weight, @@ -348,33 +366,39 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): None, # a2_scale True, # is_vnni ) + return StandardCombineInput(hidden_states=output) else: from sglang.srt.layers.moe.fused_moe_native import moe_forward_native - return moe_forward_native( + output = moe_forward_native( layer, x, topk_output, moe_runner_config, ) + return StandardCombineInput(hidden_states=output) def forward_npu( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_native import moe_forward_native + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: - return moe_forward_native( + from sglang.srt.layers.moe.fused_moe_native import moe_forward_native + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + output = moe_forward_native( layer, x, topk_output, - moe_runner_config, + self.moe_runner_config, ) + return StandardCombineInput(hidden_states=output) - def forward_tpu(self, *args, **kwargs) -> torch.Tensor: + def forward_tpu(self, *args, **kwargs) -> CombineInput: raise NotImplementedError("The TPU backend currently does not support MoE.") forward_native = forward_cpu diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index a1cdc6cba..f39acd8af 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod +from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, @@ -22,7 +23,10 @@ from sglang.srt.utils import set_weight_attrs if TYPE_CHECKING: from sglang.srt.layers.moe import MoeRunnerConfig from sglang.srt.layers.moe.ep_moe.layer import EPMoE - from sglang.srt.layers.moe.topk import StandardTopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -133,7 +137,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): layer: EPMoE, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): @@ -145,7 +149,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): w13_weight = torch.nn.Parameter( torch.empty( num_experts, - intermediate_size * 2, + intermediate_size_per_partition * 2, hidden_size // 2, dtype=torch.int8, ), @@ -159,7 +163,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): torch.empty( num_experts, hidden_size, - intermediate_size // 2, + intermediate_size_per_partition // 2, dtype=torch.int8, ), requires_grad=False, @@ -173,7 +177,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): w13_weight_scale = torch.nn.Parameter( torch.zeros( num_experts, - 2 * intermediate_size, + 2 * intermediate_size_per_partition, hidden_size // self.quant_config.group_size, dtype=torch.float32, ), @@ -186,7 +190,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): torch.zeros( num_experts, hidden_size, - intermediate_size // self.quant_config.group_size, + intermediate_size_per_partition // self.quant_config.group_size, dtype=torch.float32, ), requires_grad=False, @@ -220,13 +224,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): ) self.c_strides1 = torch.full( (num_experts, 3), - 2 * intermediate_size, + 2 * intermediate_size_per_partition, device=device, dtype=torch.int64, ) self.a_strides2 = torch.full( (num_experts, 3), - intermediate_size, + intermediate_size_per_partition, device=device, dtype=torch.int64, ) @@ -282,16 +286,21 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): ) layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + def apply( self, layer: EPMoE, - x: torch.Tensor, - topk_output: StandardTopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: - # TODO(ch-wan): move it out of this class - from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output topk_weights, topk_ids, _ = topk_output local_topk_ids = topk_ids @@ -328,6 +337,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): layer.w13_input_scale, layer.w2_input_scale, ) - if moe_runner_config.routed_scaling_factor is not None: - output *= moe_runner_config.routed_scaling_factor - return output + if self.moe_runner_config.routed_scaling_factor is not None: + output *= self.moe_runner_config.routed_scaling_factor + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 5e1aa41a6..c68591fc6 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -26,8 +27,11 @@ from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.utils import set_weight_attrs if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import StandardTopKOutput + from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) _is_fp8_fnuz = is_fp8_fnuz() @@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): @@ -218,7 +222,10 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase): # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=fp8_dtype, ), requires_grad=False, ) @@ -226,14 +233,21 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase): set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype), + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=fp8_dtype, + ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + torch.ones( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( @@ -266,25 +280,26 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase): layer.w2_weight_scale.data, requires_grad=False ) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: StandardTopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_output=topk_output, - moe_runner_config=moe_runner_config, + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, use_fp8_w8a8=True, per_channel_quant=True, - w1_scale=(layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) + return self.runner.run(dispatch_output, quant_info) diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index db9bdbec9..0d76f99a4 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -24,6 +24,8 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.parameter import ( ChannelQuantScaleParameter, ModelWeightParameter, @@ -49,8 +51,10 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() @@ -417,7 +421,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): @@ -428,7 +432,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=torch.int8, ), requires_grad=False, ) @@ -436,14 +443,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=torch.int8, + ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + torch.ones( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( @@ -483,23 +497,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): layer.w2_weight_scale.data, requires_grad=False ) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, + dispatch_output: StandardDispatchOutput, ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( - moe_runner_config.apply_router_weight_on_input, topk_weights, x + self.moe_runner_config.apply_router_weight_on_input, topk_weights, x ) - return torch.ops.sgl_kernel.fused_experts_cpu( + output = torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, layer.w2_weight, @@ -515,20 +536,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): layer.w2_input_scale, # a2_scale True, # is_vnni ) + return StandardCombineInput(hidden_states=output) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_output=topk_output, - moe_runner_config=moe_runner_config, + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, use_int8_w8a8=True, per_channel_quant=True, - w1_scale=(layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) + return self.runner.run(dispatch_output, quant_info) class NPU_W8A8LinearMethodImpl: @@ -900,7 +920,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: int, + intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: @@ -914,21 +934,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): # weight w13_weight = torch.nn.Parameter( torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=torch.int8, ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=torch.int8, + ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # scale w13_weight_scale = torch.nn.Parameter( - torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) @@ -941,7 +971,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_weight_scale, extra_weight_attrs) # offset w13_weight_offset = torch.nn.Parameter( - torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), requires_grad=False, ) layer.register_parameter("w13_weight_offset", w13_weight_offset) @@ -973,18 +1005,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False ) + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + def apply( self, layer, - x, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output topk_weights, topk_ids, _ = topk_output topk_ids = topk_ids.to(torch.int32) topk_weights = topk_weights.to(x.dtype) - return npu_fused_experts( + output = npu_fused_experts( hidden_states=x, w13=layer.w13_weight, w13_scale=layer.w13_weight_scale, @@ -994,3 +1033,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, top_k=topk_ids.shape[1], ) + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a35ba0253..fdef179a1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ScheduleBatchDisaggregationDecodeMixin, ) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank -from sglang.srt.layers.moe import is_tbo_enabled from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, SWATokenToKVPoolAllocator, diff --git a/python/sglang/srt/model_loader/__init__.py b/python/sglang/srt/model_loader/__init__.py index fa2386e3a..63f110204 100644 --- a/python/sglang/srt/model_loader/__init__.py +++ b/python/sglang/srt/model_loader/__init__.py @@ -1,16 +1,22 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py +from __future__ import annotations + +from typing import TYPE_CHECKING + from torch import nn -from sglang.srt.configs.device_config import DeviceConfig -from sglang.srt.configs.load_config import LoadConfig -from sglang.srt.configs.model_config import ModelConfig from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader from sglang.srt.model_loader.utils import ( get_architecture_class_name, get_model_architecture, ) +if TYPE_CHECKING: + from sglang.srt.configs.device_config import DeviceConfig + from sglang.srt.configs.load_config import LoadConfig + from sglang.srt.configs.model_config import ModelConfig + def get_model( *, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 1abfee2f4..d2b4c6bfc 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -1,5 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py +from __future__ import annotations + # ruff: noqa: SIM117 import collections import concurrent @@ -14,7 +16,17 @@ import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + cast, +) import huggingface_hub import numpy as np @@ -26,9 +38,7 @@ from tqdm.auto import tqdm from transformers import AutoModelForCausalLM from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig, LoadFormat -from sglang.srt.configs.model_config import ModelConfig from sglang.srt.connector import ( ConnectorType, create_remote_connector, @@ -39,7 +49,6 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_loader.utils import ( get_model_architecture, post_load_weights, @@ -70,6 +79,11 @@ from sglang.srt.utils import ( set_weight_attrs, ) +if TYPE_CHECKING: + from sglang.srt.configs.device_config import DeviceConfig + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.layers.quantization.base_config import QuantizationConfig + _is_npu = is_npu() diff --git a/python/sglang/test/test_cutlass_moe.py b/python/sglang/test/test_cutlass_moe.py index 4a67ab3b6..f6bc2b0b2 100755 --- a/python/sglang/test/test_cutlass_moe.py +++ b/python/sglang/test/test_cutlass_moe.py @@ -9,6 +9,7 @@ from transformers import AutoConfig from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig +from sglang.srt.layers.moe.topk import StandardTopKOutput # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py @@ -152,14 +153,32 @@ def run_test(tp_size, batch_size, model_config, check=False): problem_sizes2, ) + topk_output = StandardTopKOutput( + topk_weights=topk_weights, + topk_ids=topk_ids, + router_logits=torch.randn( + (batch_size, topk), device=topk_weights.device, dtype=dtype + ), + ) + + moe_runner_config = MoeRunnerConfig( + num_experts=E, + topk=topk, + hidden_size=H, + shard_intermediate_size=I, + dtype=dtype, + block_shape=block_shape, + activation="silu", + inplace=False, + ) + # Note: Triton expects non-transposed weights - moe_config = MoeRunnerConfig(inplace=False) triton_lambda = lambda: fused_experts( x, w1, w2, - (topk_weights, topk_ids, "dummy"), - moe_config, + topk_output, + moe_runner_config, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, @@ -224,8 +243,8 @@ def run_test(tp_size, batch_size, model_config, check=False): x, w1, # Original shape w2, # Original shape - (topk_weights, topk_ids, "dummy"), - moe_config, + topk_output, + moe_runner_config, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, diff --git a/test/srt/test_mla_deepseek_v3.py b/test/srt/test_mla_deepseek_v3.py index 0ebb191fb..634100fdb 100644 --- a/test/srt/test_mla_deepseek_v3.py +++ b/test/srt/test_mla_deepseek_v3.py @@ -1,3 +1,4 @@ +import os import unittest from types import SimpleNamespace @@ -49,6 +50,42 @@ class TestMLADeepseekV3(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.62) +class TestMLADeepseekV3DisableFusedFunc(CustomTestCase): + @classmethod + def setUpClass(cls): + os.environ["SGLANG_CI_DISABLE_MOE_FUSED_FUNC"] = "1" + cls.model = "lmsys/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code", "--chunked-prefill-size", "256"] + if is_cuda(): + other_args.extend(["--cuda-graph-max-bs", "2"]) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.62) + + @unittest.skipIf(is_hip(), "FA is not available.") class TestMLADeepseekV3Fa3Fp8Kvcache(CustomTestCase): @classmethod