Files
sglang/python/sglang/srt/layers/moe/moe_runner/base.py

297 lines
9.1 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
import torch
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner.triton import (
TritonRunnerCore,
TritonRunnerInput,
TritonRunnerOutput,
)
@dataclass
class MoeRunnerConfig:
    """
    Settings bundle passed to MoE runner cores and permute functions.

    Every field has a default, so ``MoeRunnerConfig()`` is valid; the MoE
    parameters are ``None`` until a caller fills them in. Field order is part
    of the positional constructor signature and must not be rearranged.
    """

    # --- MoE parameters (all unset by default) ---
    num_experts: Optional[int] = None
    num_local_experts: Optional[int] = None
    hidden_size: Optional[int] = None
    intermediate_size_per_partition: Optional[int] = None
    layer_id: Optional[int] = None
    top_k: Optional[int] = None
    num_fused_shared_experts: Optional[int] = None
    params_dtype: Optional[torch.dtype] = None

    # --- Runner configuration ---
    activation: str = "silu"  # activation function name
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
    gemm1_alpha: Optional[float] = None
    gemm1_clamp_limit: Optional[float] = None
@dataclass
class RunnerInput(ABC):
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
return self.runner_backend == MoeRunnerBackend.TRITON
class RunnerOutput(ABC):
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
return self.runner_backend == MoeRunnerBackend.TRITON
@dataclass
class MoeQuantInfo(ABC):
    """
    Base class for MoE quantization data.

    Declares no fields and no abstract members of its own; it serves as the
    common parent type for quantization info objects.
    """
class MoeRunnerCore(ABC):
def __init__(self, config: MoeRunnerConfig):
self.config = config
@abstractmethod
def run(
self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
) -> RunnerOutput:
pass
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
return self.runner_backend == MoeRunnerBackend.TRITON
class FusedOpPool:
    """
    Class-level registry of fused functions.

    Entries are keyed by the pair ``(a2a_backend_name, runner_backend_name)``
    and shared by the whole process (the dict lives on the class).
    """

    # Keyed by (A2A backend name, runner backend name).
    _fused_funcs: dict[Tuple[str, str], Callable] = {}

    @classmethod
    def register_fused_func(
        cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
    ):
        """
        Register a fused function for the given A2A backend and runner backend.

        :param a2a_backend_name: The MoeA2ABackend name.
        :param runner_backend_name: The MoeRunnerBackend name.
        :param fused_func: The function to register.
        :raises ValueError: If the pair is already registered, or if either
            name is not accepted by its enum.
        """
        key = (a2a_backend_name, runner_backend_name)
        if key in cls._fused_funcs:
            raise ValueError(
                f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
            )
        # Validate explicitly instead of with ``assert``: asserts are removed
        # under ``python -O``, and the enum lookup raises on an unknown value
        # before an assert message could ever be shown.
        # NOTE(review): assumes both classes are standard ``enum.Enum`` types
        # whose value lookup raises ValueError for unknown names.
        try:
            MoeA2ABackend(a2a_backend_name)
        except ValueError as err:
            raise ValueError(f"Invalid dispatch name: {a2a_backend_name}") from err
        try:
            MoeRunnerBackend(runner_backend_name)
        except ValueError as err:
            raise ValueError(f"Invalid runner name: {runner_backend_name}") from err
        cls._fused_funcs[key] = fused_func

    @classmethod
    def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
        """
        Look up the fused function registered for a backend pair.

        :param dispatch_name: The A2A backend name used at registration.
        :param runner_name: The MoeRunnerBackend name used at registration.
        :return: The registered function, or None if the pair is not registered.
        """
        return cls._fused_funcs.get((dispatch_name, runner_name))
class PermuteMethodPool:
    """
    Class-level registry of pre-permute and post-permute functions.

    Pre-permute functions are keyed by (dispatch output format, runner
    backend); post-permute functions by (runner backend, combine input
    format). Keys are stored as the plain string names, and both registration
    and lookup accept either the string name or an enum member whose
    ``.value`` is that name.
    """

    # Keyed by (DispatchOutputFormat name, MoeRunnerBackend name).
    _pre_permute_methods: dict[Tuple[str, str], Callable] = {}
    # Keyed by (MoeRunnerBackend name, CombineInputFormat name).
    _post_permute_methods: dict[Tuple[str, str], Callable] = {}

    @staticmethod
    def _key_part(name_or_member):
        """Return the registry key component for a string name or enum member."""
        # Enum members carry their name in ``.value``; plain strings have no
        # such attribute and are used unchanged. Normalising both sides keeps
        # string-registered entries reachable from enum-typed lookups.
        return getattr(name_or_member, "value", name_or_member)

    @staticmethod
    def _require_member(enum_cls, name, message: str) -> None:
        """Raise ValueError(message) if ``enum_cls`` does not accept ``name``."""
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O``, and the enum lookup raises on an unknown value before
        # an assert message could ever be shown.
        # NOTE(review): assumes standard ``enum.Enum`` lookup semantics
        # (ValueError for an unknown value).
        try:
            enum_cls(name)
        except ValueError as err:
            raise ValueError(message) from err

    @classmethod
    def register_pre_permute(
        cls,
        dispatch_output_name: str,
        runner_backend_name: str,
        permute_func: Callable,
    ):
        """
        Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.

        :param dispatch_output_name: The DispatchOutputFormat name.
        :param runner_backend_name: The MoeRunnerBackend name.
        :param permute_func: The permute function to register.
        :raises ValueError: If the pair is already registered, or if either
            name is not accepted by its enum.
        """
        key = (
            cls._key_part(dispatch_output_name),
            cls._key_part(runner_backend_name),
        )
        if key in cls._pre_permute_methods:
            raise ValueError(
                f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
            )
        cls._require_member(
            DispatchOutputFormat,
            dispatch_output_name,
            f"Invalid dispatch output name: {dispatch_output_name}",
        )
        cls._require_member(
            MoeRunnerBackend,
            runner_backend_name,
            f"Invalid runner backend name: {runner_backend_name}",
        )
        cls._pre_permute_methods[key] = permute_func

    @classmethod
    def register_post_permute(
        cls,
        runner_backend_name: str,
        combine_input_name: str,
        permute_func: Callable,
    ):
        """
        Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.

        :param runner_backend_name: The MoeRunnerBackend name.
        :param combine_input_name: The CombineInputFormat name.
        :param permute_func: The permute function to register.
        :raises ValueError: If the pair is already registered, or if either
            name is not accepted by its enum.
        """
        key = (
            cls._key_part(runner_backend_name),
            cls._key_part(combine_input_name),
        )
        if key in cls._post_permute_methods:
            raise ValueError(
                f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
            )
        cls._require_member(
            MoeRunnerBackend,
            runner_backend_name,
            f"Invalid runner backend name: {runner_backend_name}",
        )
        cls._require_member(
            CombineInputFormat,
            combine_input_name,
            f"Invalid combine input name: {combine_input_name}",
        )
        cls._post_permute_methods[key] = permute_func

    @classmethod
    def get_pre_permute(
        cls,
        dispatch_output_format: DispatchOutputFormat | str,
        runner_input_format: MoeRunnerBackend | str,
    ) -> Callable:
        """
        Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.

        :param dispatch_output_format: The DispatchOutputFormat member or its name.
        :param runner_input_format: The MoeRunnerBackend member or its name.
        :return: The registered permute function.
        :raises AssertionError: If no function is registered for the pair.
        """
        key = (
            cls._key_part(dispatch_output_format),
            cls._key_part(runner_input_format),
        )
        pre_permute_func = cls._pre_permute_methods.get(key)
        if pre_permute_func is None:
            # Raised explicitly (same type the former ``assert`` produced) so
            # the check is not dropped under ``python -O``.
            raise AssertionError(
                f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
            )
        return pre_permute_func

    @classmethod
    def get_post_permute(
        cls,
        runner_output_format: MoeRunnerBackend | str,
        combine_input_format: CombineInputFormat | str,
    ) -> Callable:
        """
        Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.

        :param runner_output_format: The MoeRunnerBackend member or its name.
        :param combine_input_format: The CombineInputFormat member or its name.
        :return: The registered permute function.
        :raises AssertionError: If no function is registered for the pair.
        """
        key = (
            cls._key_part(runner_output_format),
            cls._key_part(combine_input_format),
        )
        post_permute_func = cls._post_permute_methods.get(key)
        if post_permute_func is None:
            # Raised explicitly (same type the former ``assert`` produced) so
            # the check is not dropped under ``python -O``.
            raise AssertionError(
                f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
            )
        return post_permute_func
def register_fused_func(
    a2a_backend_name: str,
    runner_backend_name: str,
) -> Callable:
    """
    Build a decorator that records the decorated function in ``FusedOpPool``.

    :param a2a_backend_name: The A2A backend name.
    :param runner_backend_name: The MoeRunnerBackend name.
    :return: A decorator that registers its argument under the
        (a2a_backend_name, runner_backend_name) pair and returns it unchanged.
    """

    def _register(fused_func: Callable):
        # Registration runs once, when the decorator is applied.
        FusedOpPool.register_fused_func(
            a2a_backend_name=a2a_backend_name,
            runner_backend_name=runner_backend_name,
            fused_func=fused_func,
        )
        return fused_func

    return _register
def register_pre_permute(
    dispatch_output_name: str,
    runner_backend_name: str,
) -> Callable:
    """
    Build a decorator that records a pre-permute function in ``PermuteMethodPool``.

    :param dispatch_output_name: The DispatchOutputFormat name.
    :param runner_backend_name: The MoeRunnerBackend name.
    :return: A decorator that registers its argument under the
        (dispatch_output_name, runner_backend_name) pair and returns it unchanged.
    """

    def _register(
        permute_func: Callable[
            [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
        ]
    ) -> Callable:
        # Registration runs once, when the decorator is applied.
        PermuteMethodPool.register_pre_permute(
            dispatch_output_name=dispatch_output_name,
            runner_backend_name=runner_backend_name,
            permute_func=permute_func,
        )
        return permute_func

    return _register
def register_post_permute(
    runner_backend_name: str,
    combine_input_name: str,
) -> Callable:
    """
    Build a decorator that records a post-permute function in ``PermuteMethodPool``.

    :param runner_backend_name: The MoeRunnerBackend name.
    :param combine_input_name: The CombineInputFormat name.
    :return: A decorator that registers its argument under the
        (runner_backend_name, combine_input_name) pair and returns it unchanged.
    """

    def _register(
        permute_func: Callable[
            [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
        ]
    ) -> Callable:
        # Registration runs once, when the decorator is applied.
        PermuteMethodPool.register_post_permute(
            runner_backend_name=runner_backend_name,
            combine_input_name=combine_input_name,
            permute_func=permute_func,
        )
        return permute_func

    return _register