[7/N] MoE Refactor: the implementation of new framework (#9269)
This commit is contained in:
@@ -11,6 +11,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -19,17 +22,19 @@ from abc import ABC
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable, get_bool_env_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --------------------------------------- Entrypoint -----------------------------------------
|
||||
@@ -43,7 +48,7 @@ class ExpertDistributionRecorder(ABC):
|
||||
@staticmethod
|
||||
def init_new(
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
):
|
||||
if server_args.expert_distribution_recorder_mode is not None:
|
||||
@@ -118,7 +123,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
):
|
||||
self._server_args = server_args
|
||||
@@ -279,7 +284,7 @@ class _SinglePassGatherer(ABC):
|
||||
@staticmethod
|
||||
def init_new(
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
) -> "_SinglePassGatherer":
|
||||
if server_args.expert_distribution_recorder_mode == "per_token":
|
||||
@@ -307,7 +312,7 @@ class _SinglePassGatherer(ABC):
|
||||
|
||||
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
||||
|
||||
def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
|
||||
def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):
|
||||
self._expert_location_metadata = expert_location_metadata
|
||||
self._rank = rank
|
||||
|
||||
@@ -346,7 +351,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
):
|
||||
super().__init__(expert_location_metadata, rank)
|
||||
@@ -561,7 +566,7 @@ class _Accumulator(ABC):
|
||||
@staticmethod
|
||||
def init_new(
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
) -> "_Accumulator":
|
||||
return _Accumulator.get_class(server_args)(
|
||||
@@ -580,7 +585,7 @@ class _Accumulator(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
):
|
||||
self._server_args = server_args
|
||||
|
||||
@@ -11,21 +11,26 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.eplb import eplb_algorithms
|
||||
from sglang.srt.model_loader import get_model_architecture
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.utils import (
|
||||
DeepEPMode,
|
||||
MoeA2ABackend,
|
||||
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
|
||||
__all__ = [
|
||||
"DeepEPMode",
|
||||
"MoeA2ABackend",
|
||||
"MoeRunner",
|
||||
"MoeRunnerConfig",
|
||||
"MoeRunnerBackend",
|
||||
"initialize_moe_config",
|
||||
|
||||
@@ -8,16 +8,18 @@ from torch.nn import functional as F
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
|
||||
|
||||
def fused_moe_forward_native(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> torch.Tensor:
|
||||
|
||||
x, topk_output = dispatch_output
|
||||
moe_runner_config = layer.moe_runner_config
|
||||
|
||||
if moe_runner_config.apply_router_weight_on_input:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# NOTE: this file will be separated into sglang/srt/layers/moe/moe_runner/triton_utils.py
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
|
||||
|
||||
"""Fused MoE kernel."""
|
||||
@@ -6,13 +7,12 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
direct_register_custom_op,
|
||||
@@ -26,6 +26,9 @@ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_c
|
||||
from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
|
||||
from .moe_align_block_size import moe_align_block_size
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
|
||||
@@ -23,8 +23,13 @@ from sglang.srt.layers.moe import (
|
||||
get_moe_runner_backend,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
CombineInput,
|
||||
StandardDispatcher,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -152,16 +157,6 @@ class FusedMoE(torch.nn.Module):
|
||||
self.expert_map_cpu = None
|
||||
self.expert_map_gpu = None
|
||||
|
||||
self.moe_runner_config = MoeRunnerConfig(
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
|
||||
|
||||
if enable_flashinfer_cutlass_moe and quant_config is None:
|
||||
@@ -196,13 +191,6 @@ class FusedMoE(torch.nn.Module):
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
|
||||
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||
self.use_triton_kernels
|
||||
)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||
@@ -213,12 +201,40 @@ class FusedMoE(torch.nn.Module):
|
||||
and self.use_flashinfer_mxfp4_moe
|
||||
):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.moe_runner_config = MoeRunnerConfig(
|
||||
num_experts=num_experts,
|
||||
num_local_experts=self.num_local_experts,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size_per_partition=self.intermediate_size_per_partition,
|
||||
layer_id=layer_id,
|
||||
top_k=top_k,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
params_dtype=params_dtype,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
|
||||
self.use_triton_kernels
|
||||
)
|
||||
else:
|
||||
self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
|
||||
self, prefix
|
||||
)
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
num_experts=self.num_local_experts,
|
||||
hidden_size=hidden_size,
|
||||
# FIXME: figure out which intermediate_size to use
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
intermediate_size_per_partition=self.intermediate_size_per_partition,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=(
|
||||
@@ -229,6 +245,9 @@ class FusedMoE(torch.nn.Module):
|
||||
with_bias=with_bias,
|
||||
)
|
||||
|
||||
self.quant_method.create_moe_runner(self, self.moe_runner_config)
|
||||
self.dispatcher = StandardDispatcher()
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
shard_id: str,
|
||||
@@ -811,16 +830,17 @@ class FusedMoE(torch.nn.Module):
|
||||
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
|
||||
raise NotImplementedError()
|
||||
|
||||
# Matrix multiply.
|
||||
with use_symmetric_memory(get_tp_group()) as sm:
|
||||
dispatch_output = self.dispatcher.dispatch(
|
||||
hidden_states=hidden_states, topk_output=topk_output
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=self.moe_runner_config,
|
||||
)
|
||||
sm.tag(final_hidden_states)
|
||||
# TODO: consider using symmetric memory
|
||||
combine_input = self.quant_method.apply(
|
||||
layer=self,
|
||||
dispatch_output=dispatch_output,
|
||||
)
|
||||
|
||||
final_hidden_states = self.dispatcher.combine(combine_input)
|
||||
|
||||
final_hidden_states = final_hidden_states[
|
||||
..., :origin_hidden_states_dim
|
||||
@@ -955,7 +975,6 @@ class FlashInferFusedMoE(FusedMoE):
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=self.moe_runner_config,
|
||||
)
|
||||
|
||||
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
|
||||
|
||||
__all__ = ["MoeRunnerConfig"]
|
||||
__all__ = ["MoeRunnerConfig", "MoeRunner"]
|
||||
|
||||
@@ -1,9 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
CombineInputFormat,
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner.triton import (
|
||||
TritonRunnerCore,
|
||||
TritonRunnerInput,
|
||||
TritonRunnerOutput,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoeRunnerConfig:
|
||||
|
||||
# MoE parameters
|
||||
num_experts: Optional[int] = None
|
||||
num_local_experts: Optional[int] = None
|
||||
hidden_size: Optional[int] = None
|
||||
intermediate_size_per_partition: Optional[int] = None
|
||||
layer_id: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
num_fused_shared_experts: Optional[int] = None
|
||||
params_dtype: Optional[torch.dtype] = None
|
||||
|
||||
# Runner configuration
|
||||
activation: str = "silu"
|
||||
apply_router_weight_on_input: bool = False
|
||||
inplace: bool = True
|
||||
@@ -11,3 +43,254 @@ class MoeRunnerConfig:
|
||||
routed_scaling_factor: Optional[float] = None
|
||||
gemm1_alpha: Optional[float] = None
|
||||
gemm1_clamp_limit: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
class RunnerInput(ABC):
    """Backend-specific bundle of tensors fed into a `MoeRunnerCore.run`.

    Concrete subclasses (e.g. the Triton variant) carry whatever the
    backend kernel needs; `runner_backend` identifies which one this is.
    """

    @property
    @abstractmethod
    def runner_backend(self) -> MoeRunnerBackend: ...

    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
        """Narrow this input to the Triton variant when the backend matches."""
        return self.runner_backend == MoeRunnerBackend.TRITON
|
||||
|
||||
|
||||
class RunnerOutput(ABC):
    """Backend-specific result produced by a `MoeRunnerCore.run`.

    NOTE(review): unlike ``RunnerInput`` this base is not a ``@dataclass`` —
    presumably intentional, but worth confirming for consistency.
    """

    @property
    @abstractmethod
    def runner_backend(self) -> MoeRunnerBackend: ...

    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
        """Narrow this output to the Triton variant when the backend matches."""
        return self.runner_backend == MoeRunnerBackend.TRITON
|
||||
|
||||
|
||||
@dataclass
class MoeQuantInfo(ABC):
    """Moe quantization data.

    Marker base class: concrete backends subclass this with their
    weight/scale fields (see the Triton variant).
    """
|
||||
|
||||
|
||||
class MoeRunnerCore(ABC):
    """Abstract execution core: maps a `RunnerInput` to a `RunnerOutput`.

    Holds the shared `MoeRunnerConfig`; subclasses implement `run` for a
    specific backend.
    """

    def __init__(self, config: MoeRunnerConfig):
        self.config = config

    @abstractmethod
    def run(
        self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
    ) -> RunnerOutput:
        """Execute the MoE kernel; `running_state` is a scratch dict shared
        across the pre-permute / run / post-permute pipeline."""
        ...

    @property
    @abstractmethod
    def runner_backend(self) -> MoeRunnerBackend: ...

    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
        """Narrow this core to the Triton variant when the backend matches."""
        return self.runner_backend == MoeRunnerBackend.TRITON
|
||||
|
||||
|
||||
class FusedOpPool:
    """Registry of fused dispatch+run functions.

    Keys are ``(a2a_backend_name, runner_backend_name)`` string pairs;
    values are the fused callables registered via `register_fused_func`.
    """

    # Keys are (a2a backend name, runner backend name) tuples, not bare
    # strings — the previous dict[str, Callable] annotation was wrong.
    _fused_funcs: dict[Tuple[str, str], Callable] = {}

    @classmethod
    def register_fused_func(
        cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
    ):
        """Register `fused_func` for the backend pair; raises on duplicates
        or unknown backend names."""
        key = (a2a_backend_name, runner_backend_name)
        if key in cls._fused_funcs:
            raise ValueError(
                f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
            )
        # Validate by constructing the enums directly: the constructor raises
        # ValueError on an unknown name. The previous `assert Enum(name)` form
        # was stripped under `python -O`, silently skipping validation.
        MoeA2ABackend(a2a_backend_name)
        MoeRunnerBackend(runner_backend_name)
        cls._fused_funcs[key] = fused_func

    @classmethod
    def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
        """Return the registered fused function for the pair, or None."""
        return cls._fused_funcs.get((dispatch_name, runner_name))
|
||||
|
||||
|
||||
class PermuteMethodPool:
    """Registry of pre-/post-permute functions bridging dispatchers and runners.

    Pre-permute: converts a DispatchOutput into a backend RunnerInput.
    Post-permute: converts a backend RunnerOutput into a CombineInput.
    """

    # Registration and lookup both use *name strings* (enum .value), so the
    # keys are (str, str) tuples — the previous annotations claimed the keys
    # were the enum objects themselves, which never matched what was stored.
    _pre_permute_methods: dict[Tuple[str, str], Callable] = {}
    _post_permute_methods: dict[Tuple[str, str], Callable] = {}

    @classmethod
    def register_pre_permute(
        cls,
        dispatch_output_name: str,
        runner_backend_name: str,
        permute_func: Callable,
    ):
        """
        Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.

        :param dispatch_output_name: The DispatchOutputFormat name.
        :param runner_backend_name: The MoeRunnerBackend name.
        :param permute_func: The permute function to register.
        :raises ValueError: if the pair is already registered or a name is unknown.
        """
        key = (dispatch_output_name, runner_backend_name)
        if key in cls._pre_permute_methods:
            raise ValueError(
                f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
            )
        # Enum constructors raise ValueError on unknown names; calling them
        # directly (instead of inside `assert`) keeps validation under `-O`.
        DispatchOutputFormat(dispatch_output_name)
        MoeRunnerBackend(runner_backend_name)
        cls._pre_permute_methods[key] = permute_func

    @classmethod
    def register_post_permute(
        cls,
        runner_backend_name: str,
        combine_input_name: str,
        permute_func: Callable,
    ):
        """
        Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.

        :param runner_backend_name: The MoeRunnerBackend name.
        :param combine_input_name: The CombineInputFormat name.
        :param permute_func: The permute function to register.
        :raises ValueError: if the pair is already registered or a name is unknown.
        """
        key = (runner_backend_name, combine_input_name)
        if key in cls._post_permute_methods:
            raise ValueError(
                f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
            )
        MoeRunnerBackend(runner_backend_name)
        CombineInputFormat(combine_input_name)
        cls._post_permute_methods[key] = permute_func

    @classmethod
    def get_pre_permute(
        cls,
        dispatch_output_format: DispatchOutputFormat,
        runner_input_format: MoeRunnerBackend,
    ) -> Callable:
        """
        Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.

        :param dispatch_output_format: The DispatchOutputFormat type.
        :param runner_input_format: The MoeRunnerBackend type.
        :return: The registered permute function.
        :raises AssertionError: if no function is registered for the pair.
        """
        key = (dispatch_output_format, runner_input_format)
        pre_permute_func = cls._pre_permute_methods.get(key)
        assert (
            pre_permute_func is not None
        ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
        return pre_permute_func

    @classmethod
    def get_post_permute(
        cls,
        runner_output_format: MoeRunnerBackend,
        combine_input_format: CombineInputFormat,
    ) -> Callable:
        """
        Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.

        :param runner_output_format: The MoeRunnerBackend type.
        :param combine_input_format: The CombineInputFormat type.
        :return: The registered permute function.
        :raises AssertionError: if no function is registered for the pair.
        """
        key = (runner_output_format, combine_input_format)
        post_permute_func = cls._post_permute_methods.get(key)
        assert (
            post_permute_func is not None
        ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
        return post_permute_func
|
||||
|
||||
|
||||
def register_fused_func(
    a2a_backend_name: str,
    runner_backend_name: str,
) -> Callable:
    """
    Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.

    :param a2a_backend_name: The A2A backend name.
    :param runner_backend_name: The MoeRunnerBackend name.
    :return: The decorator function.
    """

    def decorator(fused_func: Callable):
        # Delegate the actual bookkeeping (and name validation) to the pool.
        FusedOpPool.register_fused_func(
            a2a_backend_name, runner_backend_name, fused_func
        )
        return fused_func

    return decorator
|
||||
|
||||
|
||||
def register_pre_permute(
    dispatch_output_name: str,
    runner_backend_name: str,
) -> Callable:
    """
    Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.

    :param dispatch_output_name: The DispatchOutputFormat name.
    :param runner_backend_name: The MoeRunnerBackend name.
    :return: The decorator function.
    """

    def decorator(
        permute_func: Callable[
            [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
        ]
    ) -> Callable:
        # Registration (and name validation) happens in the pool.
        PermuteMethodPool.register_pre_permute(
            dispatch_output_name, runner_backend_name, permute_func
        )
        return permute_func

    return decorator
|
||||
|
||||
|
||||
def register_post_permute(
    runner_backend_name: str,
    combine_input_name: str,
) -> Callable:
    """
    Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.

    :param runner_backend_name: The MoeRunnerBackend name.
    :param combine_input_name: The CombineInputFormat name.
    :return: The decorator function.
    """

    def decorator(
        permute_func: Callable[
            [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
        ]
    ) -> Callable:
        # Registration (and name validation) happens in the pool.
        PermuteMethodPool.register_post_permute(
            runner_backend_name, combine_input_name, permute_func
        )
        return permute_func

    return decorator
|
||||
|
||||
84
python/sglang/srt/layers/moe/moe_runner/runner.py
Normal file
84
python/sglang/srt/layers/moe/moe_runner/runner.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sglang.srt.layers.moe.moe_runner.base import (
|
||||
FusedOpPool,
|
||||
MoeRunnerConfig,
|
||||
PermuteMethodPool,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
CombineInput,
|
||||
CombineInputFormat,
|
||||
DispatchOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
|
||||
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MoeRunner:
    """Drives one MoE computation: fused path if registered, otherwise
    pre-permute -> backend core -> post-permute."""

    def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
        self.runner_backend = runner_backend
        self.config = config
        self.fused_func = None

        # Only the Triton core exists so far.
        if runner_backend.is_triton():
            self.runner_core = TritonRunnerCore(config)
        else:
            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")

        # Look up a fused (dispatch+run) implementation for this
        # (A2A backend, runner backend) pair, if one was registered.
        self.fused_func = FusedOpPool.get_fused_func(
            get_moe_a2a_backend().value, runner_backend.value
        )

        # CI escape hatch: force the unfused pipeline.
        if os.environ.get("SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0") == "1":
            logger.info(
                "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
            )
            self.fused_func = None

    def run(
        self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
    ) -> CombineInput:
        """Run MoE for one dispatched batch and return the combine input."""
        if self.fused_func is not None:
            return self.fused_func(dispatch_output, quant_info, self.config)

        dispatch_format = dispatch_output.format.value
        runner_format = self.runner_core.runner_backend.value

        self.pre_permute_func = PermuteMethodPool.get_pre_permute(
            dispatch_format, runner_format
        )
        running_state = {}
        runner_input = self.pre_permute_func(
            dispatch_output, quant_info, self.config, running_state
        )

        runner_output = self.runner_core.run(runner_input, quant_info, running_state)

        # NOTE(review): the combine format is taken from the *dispatch*
        # output's format — presumably the two enums share values; confirm.
        combine_format = dispatch_output.format.value
        self.post_permute_func = PermuteMethodPool.get_post_permute(
            runner_format, combine_format
        )
        return self.post_permute_func(
            runner_output, quant_info, self.config, running_state
        )
|
||||
442
python/sglang/srt/layers/moe/moe_runner/triton.py
Normal file
442
python/sglang/srt/layers/moe/moe_runner/triton.py
Normal file
@@ -0,0 +1,442 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.moe.moe_runner.base import (
|
||||
MoeQuantInfo,
|
||||
MoeRunnerConfig,
|
||||
MoeRunnerCore,
|
||||
RunnerInput,
|
||||
RunnerOutput,
|
||||
register_fused_func,
|
||||
register_post_permute,
|
||||
register_pre_permute,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
StandardCombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
_use_aiter = bool(int(os.getenv("SGLANG_MOE_USE_AITER", "0")))
|
||||
_MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
|
||||
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
elif _is_cpu and _is_cpu_amx_available:
|
||||
pass
|
||||
elif _is_hip:
|
||||
from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul
|
||||
|
||||
if _use_aiter:
|
||||
try:
|
||||
from aiter import moe_sum
|
||||
except ImportError:
|
||||
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||
|
||||
|
||||
if _is_cuda or _is_hip:
|
||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
|
||||
|
||||
@dataclass
class TritonRunnerInput(RunnerInput):
    """Inputs to the Triton MoE kernels: the dispatched hidden states plus
    top-k routing tensors and the block-aligned token ordering."""

    hidden_states: torch.Tensor
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    sorted_token_ids: torch.Tensor
    expert_ids: torch.Tensor
    num_tokens_post_padded: torch.Tensor

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.TRITON
|
||||
|
||||
|
||||
@dataclass
class TritonRunnerOutput(RunnerOutput):
    """Output of the Triton MoE core: the combined hidden states."""

    hidden_states: torch.Tensor

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.TRITON
|
||||
|
||||
|
||||
@dataclass
class TritonMoeQuantInfo(MoeQuantInfo):
    """Weights, biases, scales and quantization flags for the Triton runner.

    Only `w13_weight` and `w2_weight` are required; every other field is an
    optional quantization parameter. Exactly which scales/zero-points apply
    depends on the use_* flags below.
    """

    # Fused gate/up projection and down projection weights.
    w13_weight: torch.Tensor
    w2_weight: torch.Tensor
    # Optional biases for the two projections.
    b13: Optional[torch.Tensor] = None
    b2: Optional[torch.Tensor] = None
    # Quantization scheme selectors (mutually exclusive by convention —
    # TODO confirm; the code here does not enforce it).
    use_fp8_w8a8: bool = False
    use_int8_w8a8: bool = False
    use_int8_w8a16: bool = False
    use_int4_w4a16: bool = False
    per_channel_quant: bool = False
    # Weight scales / zero-points.
    w13_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None
    w13_zp: Optional[torch.Tensor] = None
    w2_zp: Optional[torch.Tensor] = None
    # Activation scales.
    a13_scale: Optional[torch.Tensor] = None
    a2_scale: Optional[torch.Tensor] = None
    # Block-wise quantization shape, when block quantization is used.
    block_shape: Optional[List[int]] = None
|
||||
|
||||
|
||||
class TritonRunnerCore(MoeRunnerCore):
    """MoE runner core executing the fused Triton kernel pipeline.

    ``run`` performs: GEMM1 (fused gate+up projection) -> activation
    (SiLU or GELU, optionally swiglu with alpha/limit) -> GEMM2 (down
    projection) -> top-k combine/reduction into the output tensor.
    """

    def __init__(self, config: MoeRunnerConfig):
        super().__init__(config)

    def run(
        self,
        runner_input: TritonRunnerInput,
        quant_info: TritonMoeQuantInfo,
        running_state: dict,
    ) -> TritonRunnerOutput:
        """Execute the fused MoE computation and return combined states.

        Args:
            runner_input: permuted activations plus alignment metadata.
            quant_info: expert weights and quantization parameters.
            running_state: per-call scratch dict; must already contain the
                Triton launch config under key ``"config"`` (set during
                pre-permute).
        """

        # TODO: move these functions to the triton runner
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
            invoke_fused_moe_kernel,
            moe_sum_reduce_torch_compile,
            moe_sum_reduce_triton,
            swiglu_with_alpha_and_limit,
        )

        hidden_states = runner_input.hidden_states
        topk_weights = runner_input.topk_weights
        topk_ids = runner_input.topk_ids
        sorted_token_ids = runner_input.sorted_token_ids
        expert_ids = runner_input.expert_ids
        num_tokens_post_padded = runner_input.num_tokens_post_padded

        w13 = quant_info.w13_weight
        w2 = quant_info.w2_weight
        b13 = quant_info.b13
        b2 = quant_info.b2
        a13_scale = quant_info.a13_scale
        a2_scale = quant_info.a2_scale
        w13_scale = quant_info.w13_scale
        w2_scale = quant_info.w2_scale
        w13_zp = quant_info.w13_zp
        w2_zp = quant_info.w2_zp
        block_shape = quant_info.block_shape
        per_channel_quant = quant_info.per_channel_quant
        use_fp8_w8a8 = quant_info.use_fp8_w8a8
        use_int8_w8a8 = quant_info.use_int8_w8a8
        use_int8_w8a16 = quant_info.use_int8_w8a16
        use_int4_w4a16 = quant_info.use_int4_w4a16

        activation = self.config.activation
        no_combine = self.config.no_combine
        inplace = self.config.inplace
        gemm1_alpha = self.config.gemm1_alpha
        gemm1_limit = self.config.gemm1_clamp_limit
        routed_scaling_factor = self.config.routed_scaling_factor
        apply_router_weight_on_input = self.config.apply_router_weight_on_input

        M = hidden_states.shape[0]  # number of tokens
        # N = 2 * intermediate size (gate and up halves fused in w13).
        E, N, _ = w13.shape
        # Triton compute dtype mirrors the activation dtype.
        compute_type = (
            tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
        )

        # GEMM1 output: one row per (token, top-k expert) pair.
        intermediate_cache1 = torch.empty(
            (M, topk_ids.shape[1], N),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

        invoke_fused_moe_kernel(
            hidden_states,
            w13,
            b13,
            intermediate_cache1,
            a13_scale,
            w13_scale,
            w13_zp,
            topk_weights,
            topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            topk_ids.shape[1],
            running_state["config"],
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
        )

        # The gated activation halves the hidden dim: N -> N // 2.
        intermediate_cache2 = torch.empty(
            (M * topk_ids.shape[1], N // 2),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

        if activation == "silu":
            if gemm1_alpha is not None:
                # swiglu variant with scaling/clamping (returns a new tensor).
                assert gemm1_limit is not None
                intermediate_cache2 = swiglu_with_alpha_and_limit(
                    intermediate_cache1.view(-1, N),
                    gemm1_alpha,
                    gemm1_limit,
                )
            elif _is_cuda:
                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
            else:
                # vllm op takes (out, input) argument order.
                vllm_ops.silu_and_mul(
                    intermediate_cache2, intermediate_cache1.view(-1, N)
                )
        elif activation == "gelu":
            assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
            assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
            if _is_cuda:
                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
            else:
                vllm_ops.gelu_and_mul(
                    intermediate_cache2, intermediate_cache1.view(-1, N)
                )
        else:
            raise ValueError(f"Unsupported activation: {activation=}")

        # GEMM2 output, still separated per top-k expert.
        intermediate_cache3 = torch.empty(
            (M, topk_ids.shape[1], w2.shape[1]),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

        if no_combine:
            assert not inplace
            # Keep per-expert outputs separate; caller combines later.
            out_hidden_states = torch.empty(
                (M, topk_ids.shape[1], w2.shape[1]),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )
        elif inplace:
            out_hidden_states = hidden_states
        else:
            out_hidden_states = torch.empty_like(hidden_states)

        invoke_fused_moe_kernel(
            intermediate_cache2,
            w2,
            b2,
            (
                # With top-k == 1 (or no_combine) GEMM2 writes straight into
                # the output buffer, skipping the reduction below.
                intermediate_cache3
                if not no_combine and topk_ids.shape[1] != 1
                else out_hidden_states.unsqueeze(0)
            ),
            a2_scale,
            w2_scale,
            w2_zp,
            topk_weights,
            topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            # Router weights are applied exactly once: here, unless they were
            # already folded into the input of GEMM1.
            not apply_router_weight_on_input,
            1,
            running_state["config"],
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
        )

        if routed_scaling_factor is None:
            routed_scaling_factor = 1.0

        # ---- combine: reduce over the top-k expert dimension ----
        if no_combine:
            pass
        elif _is_cuda:
            if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
                pass  # we write directly into out_hidden_states
            elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
                # NOTE(review): .squeeze(dim=1) acts on torch.add's return
                # value only and does not affect out_hidden_states —
                # presumably vestigial; confirm before removing.
                torch.add(
                    intermediate_cache3[:, 0],
                    intermediate_cache3[:, 1],
                    out=out_hidden_states,
                ).squeeze(dim=1)
            else:
                # According to micro benchmark results, torch.compile can get better performance for small token.
                if M <= 32:
                    moe_sum_reduce_torch_compile(
                        intermediate_cache3.view(*intermediate_cache3.shape),
                        out_hidden_states,
                        routed_scaling_factor,
                    )
                else:
                    moe_sum_reduce_triton(
                        intermediate_cache3.view(*intermediate_cache3.shape),
                        out_hidden_states,
                        routed_scaling_factor,
                    )
        elif _is_hip:
            if _use_aiter:
                moe_sum(
                    intermediate_cache3.view(*intermediate_cache3.shape),
                    out_hidden_states,
                )
            else:
                vllm_ops.moe_sum(
                    intermediate_cache3.view(*intermediate_cache3.shape),
                    out_hidden_states,
                )
        else:
            vllm_ops.moe_sum(
                intermediate_cache3.view(*intermediate_cache3.shape),
                out_hidden_states,
            )

        return TritonRunnerOutput(
            hidden_states=out_hidden_states,
        )

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.TRITON
||||
@register_fused_func("none", "triton")
def fused_experts_none_to_triton(
    dispatch_output: StandardDispatchOutput,
    quant_info: TritonMoeQuantInfo,
    runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
    """Fused path for standard-format input.

    Delegates the whole MoE computation to the monolithic Triton
    ``fused_experts`` entry point instead of the separate
    pre-permute / run / post-permute stages.
    """
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

    combined = fused_experts(
        hidden_states=dispatch_output.hidden_states,
        w1=quant_info.w13_weight,
        w2=quant_info.w2_weight,
        topk_output=dispatch_output.topk_output,
        moe_runner_config=runner_config,
        b1=quant_info.b13,
        b2=quant_info.b2,
        use_fp8_w8a8=quant_info.use_fp8_w8a8,
        use_int8_w8a8=quant_info.use_int8_w8a8,
        use_int8_w8a16=quant_info.use_int8_w8a16,
        use_int4_w4a16=quant_info.use_int4_w4a16,
        per_channel_quant=quant_info.per_channel_quant,
        w1_scale=quant_info.w13_scale,
        w2_scale=quant_info.w2_scale,
        w1_zp=quant_info.w13_zp,
        w2_zp=quant_info.w2_zp,
        a1_scale=quant_info.a13_scale,
        a2_scale=quant_info.a2_scale,
        block_shape=quant_info.block_shape,
    )

    return StandardCombineInput(hidden_states=combined)
||||
@register_pre_permute("standard", "triton")
def pre_permute_standard_to_triton(
    dispatch_output: StandardDispatchOutput,
    quant_info: TritonMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> TritonRunnerInput:
    """Prepare standard-format dispatch output for the Triton runner.

    Picks a Triton launch config, aligns tokens into expert-contiguous
    blocks, and stashes the config in ``running_state["config"]`` for the
    subsequent ``TritonRunnerCore.run`` call.
    """

    # NOTE: this is dead code as a fused func for standard format is registered.
    # This is left here for testing and examples.

    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
        get_config_dtype_str,
        moe_align_block_size,
        try_get_optimal_moe_config,
    )
    from sglang.srt.layers.moe.topk import TopKOutputChecker

    hidden_states, topk_output = dispatch_output

    assert TopKOutputChecker.format_is_standard(topk_output)

    token_count = hidden_states.shape[0]
    expert_count = runner_config.num_local_experts

    # Padding applies only to non-block fp8/int8 w8a8 without aiter
    # (presumably matching how the weights were padded at load time —
    # confirm against the weight-loading path).
    is_padded_w8a8 = (
        (quant_info.use_fp8_w8a8 or quant_info.use_int8_w8a8)
        and quant_info.block_shape is None
        and not _use_aiter
    )
    padding_size = _MOE_PADDING_SIZE if is_padded_w8a8 else 0

    config_dtype = get_config_dtype_str(
        use_fp8_w8a8=quant_info.use_fp8_w8a8,
        use_int8_w8a8=quant_info.use_int8_w8a8,
        use_int8_w8a16=quant_info.use_int8_w8a16,
        use_int4_w4a16=quant_info.use_int4_w4a16,
        dtype=hidden_states.dtype,
    )

    choose_config = functools.partial(
        try_get_optimal_moe_config,
        quant_info.w13_weight.shape,
        (
            expert_count,
            quant_info.w2_weight.shape[1],
            quant_info.w2_weight.shape[2] - padding_size,
        ),
        topk_output.topk_ids.shape[1],
        config_dtype,
        block_shape=quant_info.block_shape,
    )
    config = choose_config(token_count)

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_output.topk_ids, config["BLOCK_SIZE_M"], expert_count
    )

    running_state["config"] = config

    return TritonRunnerInput(
        hidden_states=hidden_states,
        topk_weights=topk_output.topk_weights,
        topk_ids=topk_output.topk_ids,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
    )
||||
@register_post_permute("triton", "standard")
def post_permute_triton_to_standard(
    runner_output: TritonRunnerOutput,
    quant_info: TritonMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> StandardCombineInput:
    """Wrap the Triton runner's output back into the standard combine format."""

    # NOTE: this is dead code as a fused func for standard format is registered.
    # This is left here for testing and examples.

    return StandardCombineInput(hidden_states=runner_output.hidden_states)
@@ -1,6 +1,9 @@
|
||||
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
BaseDispatcher,
|
||||
BaseDispatcherConfig,
|
||||
CombineInput,
|
||||
CombineInputChecker,
|
||||
CombineInputFormat,
|
||||
DispatchOutput,
|
||||
DispatchOutputChecker,
|
||||
DispatchOutputFormat,
|
||||
@@ -9,21 +12,32 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPConfig,
|
||||
DeepEPDispatcher,
|
||||
DeepEPLLCombineInput,
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalCombineInput,
|
||||
DeepEPNormalOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
StandardCombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
# Public re-export surface of the token_dispatcher package.
__all__ = [
    "AscendDeepEPLLOutput",
    "BaseDispatcher",
    "BaseDispatcherConfig",
    "CombineInput",
    "CombineInputChecker",
    "CombineInputFormat",
    "DispatchOutput",
    "DispatchOutputFormat",
    "DispatchOutputChecker",
    "StandardDispatchOutput",
    "StandardCombineInput",
    "DeepEPConfig",
    "DeepEPDispatcher",
    "DeepEPNormalOutput",
    "DeepEPLLOutput",
    "DeepEPLLCombineInput",
    "DeepEPNormalCombineInput",
]
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, auto
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
|
||||
|
||||
import torch
|
||||
@@ -9,10 +9,16 @@ import torch
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPLLCombineInput,
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalCombineInput,
|
||||
DeepEPNormalOutput,
|
||||
StandardCombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
# ------------------------------ Dispatch Output -------------------------------------
|
||||
|
||||
|
||||
class DispatchOutputChecker:
|
||||
@@ -50,10 +56,10 @@ class DispatchOutputChecker:
|
||||
|
||||
class DispatchOutputFormat(Enum):
|
||||
|
||||
STANDARD = auto()
|
||||
DEEPEP_NORMAL = auto()
|
||||
DEEPEP_LL = auto()
|
||||
ASCENT_LL = auto()
|
||||
STANDARD = "standard"
|
||||
DEEPEP_NORMAL = "deepep_normal"
|
||||
DEEPEP_LL = "deepep_ll"
|
||||
ASCENT_LL = "ascent_ll"
|
||||
|
||||
def is_standard(self) -> bool:
|
||||
return self == DispatchOutputFormat.STANDARD
|
||||
@@ -78,10 +84,63 @@ class DispatchOutputFormat(Enum):
|
||||
class DispatchOutput(Protocol):
    """Protocol for dispatch outputs in different formats."""

    # TODO: add hidden_states to the protocol

    @property
    def format(self) -> DispatchOutputFormat:
        # Identifies the concrete dispatch format (standard / deepep_* / ...).
        ...
||||
# ------------------------------ Combine Input -------------------------------------
|
||||
|
||||
|
||||
class CombineInputChecker:
    """TypeGuard helpers narrowing a CombineInput to a concrete format."""

    @staticmethod
    def format_is_standard(
        combine_input: CombineInput,
    ) -> TypeGuard[StandardCombineInput]:
        return combine_input.format == CombineInputFormat.STANDARD

    @staticmethod
    def format_is_deepep_normal(
        combine_input: CombineInput,
    ) -> TypeGuard[DeepEPNormalCombineInput]:
        return combine_input.format == CombineInputFormat.DEEPEP_NORMAL

    @staticmethod
    def format_is_deepep_ll(
        combine_input: CombineInput,
    ) -> TypeGuard[DeepEPLLCombineInput]:
        return combine_input.format == CombineInputFormat.DEEPEP_LL

    @staticmethod
    def format_is_deepep(
        combine_input: CombineInput,
    ) -> TypeGuard[Union[DeepEPNormalCombineInput, DeepEPLLCombineInput]]:
        # Either DeepEP flavor (normal or low-latency) counts.
        return combine_input.format in [
            CombineInputFormat.DEEPEP_NORMAL,
            CombineInputFormat.DEEPEP_LL,
        ]
|
||||
class CombineInputFormat(Enum):
    """Wire formats a CombineInput can take."""

    STANDARD = "standard"
    DEEPEP_NORMAL = "deepep_normal"
    DEEPEP_LL = "deepep_ll"
||||
@runtime_checkable
class CombineInput(Protocol):
    """Protocol for combine inputs in different formats."""

    # TODO: add hidden_states to the protocol

    @property
    def format(self) -> CombineInputFormat:
        # Identifies the concrete combine-input format.
        ...
||||
# ------------------------------ Base Dispatcher -------------------------------------
|
||||
|
||||
|
||||
class BaseDispatcherConfig(ABC):
|
||||
"""Base class for dispatcher configs."""
|
||||
|
||||
@@ -92,9 +151,11 @@ class BaseDispatcher(ABC):
|
||||
"""Base class for dispatchers."""
|
||||
|
||||
@abstractmethod
|
||||
def dispatch(self, *args, **kwargs) -> DispatchOutput:
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs
|
||||
) -> DispatchOutput:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def combine(self, *args, **kwargs) -> torch.Tensor:
|
||||
def combine(self, combine_input: CombineInput, **kwargs) -> torch.Tensor:
|
||||
pass
|
||||
@@ -5,13 +5,15 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
|
||||
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
BaseDispatcher,
|
||||
BaseDispatcherConfig,
|
||||
CombineInput,
|
||||
CombineInputFormat,
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
@@ -56,6 +58,7 @@ class DeepEPNormalOutput(NamedTuple):
|
||||
"""DeepEP normal dispatch output."""
|
||||
|
||||
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
|
||||
# hidden_states_scale
|
||||
topk_idx: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
num_recv_tokens_per_expert: List[int]
|
||||
@@ -99,6 +102,30 @@ assert isinstance(DeepEPLLOutput, DispatchOutput)
|
||||
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
|
||||
|
||||
|
||||
class DeepEPNormalCombineInput(NamedTuple):
    """DeepEP normal combine input."""

    # No fields yet; placeholder so the protocol check below passes.
    pass

    @property
    def format(self) -> CombineInputFormat:
        return CombineInputFormat.DEEPEP_NORMAL
|
||||
class DeepEPLLCombineInput(NamedTuple):
    """DeepEP low latency combine input."""

    # No fields yet; placeholder so the protocol check below passes.
    pass

    @property
    def format(self) -> CombineInputFormat:
        return CombineInputFormat.DEEPEP_LL
|
||||
# Sanity checks against the runtime-checkable CombineInput protocol.
# NOTE: these test the class objects themselves (the `format` property
# exists as a class attribute), not instances.
assert isinstance(DeepEPNormalCombineInput, CombineInput)
assert isinstance(DeepEPLLCombineInput, CombineInput)
|
||||
class DeepEPDispatchMode(IntEnum):
    """DeepEP dispatch variants: throughput-oriented vs low-latency."""

    NORMAL = auto()
    LOW_LATENCY = auto()
||||
@@ -1,19 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
BaseDispatcher,
|
||||
CombineInput,
|
||||
CombineInputFormat,
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
|
||||
class StandardDispatchOutput(NamedTuple):
    """Standard dispatch output: hidden states plus their top-k routing."""

    hidden_states: torch.Tensor
    topk_output: TopKOutput

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.STANDARD


# Sanity check that the class satisfies the DispatchOutput protocol.
assert isinstance(StandardDispatchOutput, DispatchOutput)
||||
|
||||
class StandardCombineInput(NamedTuple):
    """Standard combine input: already-combined hidden states."""

    hidden_states: torch.Tensor

    @property
    def format(self) -> CombineInputFormat:
        return CombineInputFormat.STANDARD


# Sanity check that the class satisfies the CombineInput protocol.
assert isinstance(StandardCombineInput, CombineInput)
||||
|
||||
class StandardDispatcher(BaseDispatcher):
    """Trivial dispatcher: no cross-rank token movement at all."""

    def dispatch(
        self, hidden_states: torch.Tensor, topk_output: TopKOutput
    ) -> DispatchOutput:
        """Package the inputs unchanged into a StandardDispatchOutput."""
        return StandardDispatchOutput(
            hidden_states=hidden_states, topk_output=topk_output
        )

    def combine(self, combine_input: CombineInput) -> torch.Tensor:
        """Unwrap the combined hidden states."""
        if not isinstance(combine_input, StandardCombineInput):
            # TODO: this branch should be removed in the future
            # Legacy callers still hand us a bare tensor.
            assert isinstance(combine_input, torch.Tensor)
            return combine_input
        return combine_input.hidden_states
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_dp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.utils import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MoeA2ABackend(Enum):
|
||||
|
||||
@@ -131,7 +133,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
|
||||
global MOE_A2A_BACKEND
|
||||
if MOE_A2A_BACKEND is None:
|
||||
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
|
||||
MOE_A2A_BACKEND = MoeA2ABackend(None)
|
||||
MOE_A2A_BACKEND = MoeA2ABackend.NONE
|
||||
return MOE_A2A_BACKEND
|
||||
|
||||
|
||||
@@ -139,7 +141,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
|
||||
global MOE_RUNNER_BACKEND
|
||||
if MOE_RUNNER_BACKEND is None:
|
||||
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
|
||||
MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
|
||||
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
|
||||
return MOE_RUNNER_BACKEND
|
||||
|
||||
|
||||
@@ -147,7 +149,7 @@ def get_deepep_mode() -> DeepEPMode:
|
||||
global DEEPEP_MODE
|
||||
if DEEPEP_MODE is None:
|
||||
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
|
||||
DEEPEP_MODE = DeepEPMode("auto")
|
||||
DEEPEP_MODE = DeepEPMode.AUTO
|
||||
return DEEPEP_MODE
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
StandardDispatchOutput,
|
||||
CombineInput,
|
||||
)
|
||||
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
self.moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
# The input must currently be float16
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
orig_dtype = x.dtype
|
||||
x = x.half()
|
||||
|
||||
topk_weights, topk_ids, router_logits = topk_output
|
||||
|
||||
return fused_marlin_moe(
|
||||
output = fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
).to(orig_dtype)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
@@ -10,7 +11,7 @@ from torch import nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
|
||||
|
||||
|
||||
class QuantizeMethodBase(ABC):
|
||||
@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> CombineInput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
layer: Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
||||
# Required by column parallel or enabling merged weights
|
||||
if intermediate_size % block_n != 0:
|
||||
if intermediate_size_per_partition % block_n != 0:
|
||||
raise ValueError(
|
||||
f"The output_size of gate's and up's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}."
|
||||
)
|
||||
if tp_size > 1:
|
||||
# Required by row parallel
|
||||
if intermediate_size % block_k != 0:
|
||||
if intermediate_size_per_partition % block_k != 0:
|
||||
raise ValueError(
|
||||
f"The input_size of down's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}."
|
||||
)
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * ((intermediate_size + block_n - 1) // block_n),
|
||||
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
torch.ones(
|
||||
num_experts,
|
||||
(hidden_size + block_n - 1) // block_n,
|
||||
(intermediate_size + block_k - 1) // block_k,
|
||||
(intermediate_size_per_partition + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
# Block quant doesn't need to process weights after loading
|
||||
return
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
# Expert fusion with INT8 quantization
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=(layer.w13_weight_scale_inv),
|
||||
w2_scale=(layer.w2_weight_scale_inv),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
w13_scale=layer.w13_weight_scale_inv,
|
||||
w2_scale=layer.w2_weight_scale_inv,
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
)
|
||||
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
@@ -11,6 +11,8 @@ import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
@@ -30,8 +32,10 @@ from sglang.srt.utils import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
moe_runner_config = self.moe_runner_config
|
||||
|
||||
if (
|
||||
_use_aiter
|
||||
@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
and moe_runner_config.apply_router_weight_on_input
|
||||
):
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
return rocm_fused_experts_tkw1(
|
||||
output = rocm_fused_experts_tkw1(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
else:
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=self.weight_quant.strategy
|
||||
== QuantizationStrategy.CHANNEL,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
|
||||
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
params_dtype == torch.float16
|
||||
), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
|
||||
|
||||
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
|
||||
|
||||
# Will transpose the loaded weight along the
|
||||
# intermediate and hidden dim sizes. Will
|
||||
# shard for TP along the transposed dims
|
||||
@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
# In the case where we have actorder/g_idx,
|
||||
# we do not partition the w2 scales
|
||||
load_full_w2 = self.actorder and self.group_size != -1
|
||||
w2_scales_size = (
|
||||
intermediate_size_full if load_full_w2 else intermediate_size_per_partition
|
||||
)
|
||||
|
||||
self.is_k_full = (not self.actorder) or (
|
||||
intermediate_size_per_partition == intermediate_size_full
|
||||
)
|
||||
if load_full_w2:
|
||||
w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
|
||||
else:
|
||||
w2_scales_size = intermediate_size_per_partition
|
||||
|
||||
self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1
|
||||
|
||||
if self.strategy == "channel":
|
||||
num_groups_w2 = num_groups_w13 = 1
|
||||
@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
replace_tensor("w2_weight_scale", marlin_w2_scales)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
self.moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
topk_weights, topk_ids, router_logits = topk_output
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
output = torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight_packed,
|
||||
layer.w2_weight_packed,
|
||||
@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
num_bits=self.num_bits,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
@@ -30,6 +30,9 @@ except ImportError:
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
|
||||
from sglang.srt.layers.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@@ -81,7 +84,11 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
DispatchOutput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
|
||||
@@ -527,7 +534,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
@@ -543,18 +550,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
||||
# Required by column parallel or enabling merged weights
|
||||
if intermediate_size % block_n != 0:
|
||||
if intermediate_size_per_partition % block_n != 0:
|
||||
raise ValueError(
|
||||
f"The output_size of gate's and up's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}."
|
||||
)
|
||||
if tp_size > 1:
|
||||
# Required by row parallel
|
||||
if intermediate_size % block_k != 0:
|
||||
if intermediate_size_per_partition % block_k != 0:
|
||||
raise ValueError(
|
||||
f"The input_size of down's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}."
|
||||
)
|
||||
|
||||
@@ -564,7 +571,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // 8,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
@@ -572,20 +579,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // 8,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -601,7 +617,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * ((intermediate_size + block_n - 1) // block_n),
|
||||
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
@@ -611,7 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
torch.ones(
|
||||
num_experts,
|
||||
(hidden_size + block_n - 1) // block_n,
|
||||
(intermediate_size + block_k - 1) // block_k,
|
||||
(intermediate_size_per_partition + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -632,19 +648,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
self.c_strides1 = torch.full(
|
||||
(num_experts,),
|
||||
2 * intermediate_size,
|
||||
2 * intermediate_size_per_partition,
|
||||
device=w13_weight.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.ab_strides2 = torch.full(
|
||||
(num_experts,),
|
||||
intermediate_size,
|
||||
intermediate_size_per_partition,
|
||||
device=w2_weight.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.c_strides2 = torch.full(
|
||||
(num_experts,),
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
device=w2_weight.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
@@ -691,7 +707,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
if _is_hip: # _use_aiter: TODO: add check back after triton kernel
|
||||
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
|
||||
w13_weight_scale1 = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale1 = torch.nn.Parameter(
|
||||
@@ -984,14 +1004,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
moe_runner_config = self.moe_runner_config
|
||||
|
||||
if use_intel_amx_backend(layer):
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
@@ -1001,7 +1030,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
output = torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -1017,6 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
None, # a2_scale
|
||||
True, # is_vnni
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
if _is_hip:
|
||||
ret = self.maybe_apply_hip_fused_experts(
|
||||
@@ -1027,7 +1057,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
moe_runner_config.no_combine,
|
||||
)
|
||||
if ret is not None:
|
||||
return ret
|
||||
return StandardCombineInput(hidden_states=ret)
|
||||
|
||||
if self.use_cutlass_fused_experts_fp8:
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||
@@ -1056,17 +1086,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.problem_sizes2,
|
||||
use_fp8_blockscale=True,
|
||||
)
|
||||
# Scale by routed_scaling_factor is fused into select_experts.
|
||||
return output
|
||||
# Expert fusion with FP8 quantization
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=(
|
||||
w13_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w13_weight_scale
|
||||
@@ -1074,20 +1100,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w2_scale=(
|
||||
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
||||
),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
)
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
def apply_with_router_logits(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> torch.Tensor:
|
||||
activation = moe_runner_config.activation
|
||||
routed_scaling_factor = moe_runner_config.routed_scaling_factor
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
activation = self.moe_runner_config.activation
|
||||
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
|
||||
|
||||
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
|
||||
|
||||
|
||||
@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
StandardDispatchOutput,
|
||||
CombineInput,
|
||||
)
|
||||
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
from sglang.srt.layers.linear import set_weight_attrs
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
|
||||
intermediate_size = extra_weight_attrs.pop("intermediate_size")
|
||||
|
||||
self.is_k_full = (not self.quant_config.desc_act) or (
|
||||
intermediate_size_per_partition == intermediate_size
|
||||
)
|
||||
self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
scales_size13 = hidden_size // self.quant_config.group_size
|
||||
w2_scales_size = (
|
||||
intermediate_size
|
||||
if self.quant_config.desc_act
|
||||
else intermediate_size_per_partition
|
||||
)
|
||||
if self.quant_config.desc_act:
|
||||
w2_scales_size = intermediate_size_per_partition
|
||||
else:
|
||||
w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
|
||||
scales_size2 = w2_scales_size // self.quant_config.group_size
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
else:
|
||||
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
# Delay the import to avoid circular dependency
|
||||
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
self.moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
# The input must currently be float16
|
||||
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
topk_weights, topk_ids, router_logits = topk_output
|
||||
|
||||
return fused_marlin_moe(
|
||||
output = fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
is_k_full=self.is_k_full,
|
||||
).to(orig_dtype)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
|
||||
from sglang.srt.distributed import get_tp_group
|
||||
from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
|
||||
from sglang.srt.layers.moe import (
|
||||
MoeRunner,
|
||||
MoeRunnerBackend,
|
||||
MoeRunnerConfig,
|
||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import scaled_fp4_quant
|
||||
@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w13_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
input_dim=2,
|
||||
output_dim=1,
|
||||
@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w2_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts, hidden_size, intermediate_size, dtype=weight_dtype
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
input_dim=2,
|
||||
output_dim=1,
|
||||
@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
|
||||
# Requantize each expert's weights using the combined scale
|
||||
# w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
|
||||
# where the first intermediate_size rows are w1, the next are w3
|
||||
intermediate_size = layer.w13_weight.shape[1] // 2
|
||||
# w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
|
||||
# where the first intermediate_size_per_partition rows are w1, the next are w3
|
||||
intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
|
||||
for expert_id in range(layer.w13_weight.shape[0]):
|
||||
start = 0
|
||||
for shard_id in range(2): # w1 and w3
|
||||
# Dequantize using the original scale for this shard
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][
|
||||
start : start + intermediate_size, :
|
||||
start : start + intermediate_size_per_partition, :
|
||||
],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
# Requantize using the combined max scale
|
||||
(
|
||||
layer.w13_weight[expert_id][
|
||||
start : start + intermediate_size, :
|
||||
start : start + intermediate_size_per_partition, :
|
||||
],
|
||||
_,
|
||||
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
|
||||
start += intermediate_size
|
||||
start += intermediate_size_per_partition
|
||||
|
||||
# Update the scale parameter to be per-expert instead of per-shard
|
||||
layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
|
||||
@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_input_scale.max(), requires_grad=False
|
||||
)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False, # ModelOpt uses per-tensor quantization
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
per_channel_quant=False,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
|
||||
class ModelOptFp4Config(QuantizationConfig):
|
||||
"""Config class for FP4."""
|
||||
@@ -1278,21 +1292,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
|
||||
return self.enable_flashinfer_cutlass_moe
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
self.moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
moe_runner_config = self.moe_runner_config
|
||||
|
||||
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
|
||||
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
|
||||
# This layer was processed with flashinfer TRTLLM - delegate to its own forward
|
||||
return layer.forward(x, topk_output)
|
||||
return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
|
||||
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
assert (
|
||||
@@ -1345,13 +1370,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
tp_rank=layer.moe_tp_rank,
|
||||
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
||||
)[0]
|
||||
# Scale by routed_scaling_factor is fused into select_experts.
|
||||
if should_use_flashinfer_cutlass_moe_fp4_allgather():
|
||||
output, global_output = get_local_dp_buffer(), output
|
||||
get_tp_group().reduce_scatterv(
|
||||
global_output, output=output, sizes=get_dp_global_num_tokens()
|
||||
)
|
||||
return output
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
|
||||
@@ -1372,4 +1396,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
|
||||
).to(x.dtype)
|
||||
# Scale by routed_scaling_factor is fused into select_experts.
|
||||
return output
|
||||
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
@@ -9,6 +9,8 @@ import torch
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
|
||||
def get_weight_perm(num_bits: int):
|
||||
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
layer.register_parameter(key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
# avoid circular import
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
self.moe_runner_config.activation == "silu"
|
||||
), "Only SiLU activation is supported."
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_qweight,
|
||||
w2_weight=layer.w2_qweight,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
w1_scale=layer.w13_scales,
|
||||
w13_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w13_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size],
|
||||
)
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
@staticmethod
|
||||
def get_weight_loader(layer, weight_loader):
|
||||
|
||||
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.moe.utils import get_moe_runner_backend
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -59,8 +61,10 @@ if is_flashinfer_available():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
@@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
with_bias: bool = False,
|
||||
**extra_weight_attrs,
|
||||
@@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
# for to hold non-uniform sharded tensor as well as swizzling
|
||||
intermediate_size_per_partition_after_pad = intermediate_size
|
||||
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
|
||||
if _is_sm100_supported:
|
||||
if self.use_flashinfer:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, 256
|
||||
intermediate_size_per_partition, 256
|
||||
)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, 64
|
||||
intermediate_size_per_partition, 64
|
||||
)
|
||||
elif has_triton_kernels:
|
||||
# TODO: this is a hack to make
|
||||
# intermediate_size_per_partition_after_pad the same as the
|
||||
# per_rank_intermediate_size during weight loading
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, mxfp4_block
|
||||
intermediate_size_per_partition, mxfp4_block
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
# Fused gate_up_proj (column parallel)
|
||||
@@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
assert (
|
||||
layer.w13_weight.dim() == 3
|
||||
and layer.w13_weight.shape[0] == self.num_experts
|
||||
and layer.w13_weight.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight.shape[1]
|
||||
== self.intermediate_size_per_partition * 2
|
||||
and layer.w13_weight.shape[2] == self.hidden_size // 2
|
||||
)
|
||||
assert (
|
||||
layer.w13_weight_scale.dim() == 3
|
||||
and layer.w13_weight_scale.shape[0] == self.num_experts
|
||||
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight_scale.shape[1]
|
||||
== self.intermediate_size_per_partition * 2
|
||||
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight.dim() == 3
|
||||
and layer.w2_weight.shape[0] == self.num_experts
|
||||
and layer.w2_weight.shape[1] == self.hidden_size
|
||||
and layer.w2_weight.shape[2] == self.intermediate_size // 2
|
||||
and layer.w2_weight.shape[2]
|
||||
== self.intermediate_size_per_partition // 2
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight_scale.dim() == 3
|
||||
and layer.w2_weight_scale.shape[1] == self.hidden_size
|
||||
and layer.w2_weight_scale.shape[2]
|
||||
== self.intermediate_size // sf_block_size
|
||||
== self.intermediate_size_per_partition // sf_block_size
|
||||
)
|
||||
assert (
|
||||
layer.w13_weight_bias.dim() == 2
|
||||
and layer.w13_weight_bias.shape[0] == self.num_experts
|
||||
and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight_bias.shape[1]
|
||||
== self.intermediate_size_per_partition * 2
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight_bias.dim() == 2
|
||||
@@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
torch.stack(gemm1_scales_mxfp4_shuffled)
|
||||
.reshape(
|
||||
self.num_experts,
|
||||
2 * self.intermediate_size,
|
||||
2 * self.intermediate_size_per_partition,
|
||||
self.hidden_size // sf_block_size,
|
||||
)
|
||||
.view(torch.float8_e4m3fn)
|
||||
@@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
.reshape(
|
||||
self.num_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size // sf_block_size,
|
||||
self.intermediate_size_per_partition // sf_block_size,
|
||||
)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
@@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
moe_runner_config = self.moe_runner_config
|
||||
|
||||
if self.use_flashinfer:
|
||||
# When bf16 mode is enabled, we don't need to quantize the input,
|
||||
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
|
||||
@@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
top_k,
|
||||
None, # n_group # TODO: support n_group
|
||||
None, # topk_group # TODO: support topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
self.intermediate_size_per_partition, # padded to multiple of 256
|
||||
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
||||
layer.num_local_experts, # local num experts
|
||||
None,
|
||||
@@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
1, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
return StandardCombineInput(hidden_states=trtllm_gen_output)
|
||||
|
||||
if self.use_triton_kernels:
|
||||
assert (
|
||||
layer.moe_ep_size == 1
|
||||
), "Expert parallel is not supported when using triton kernels"
|
||||
if self.with_bias:
|
||||
return self.triton_kernel_moe_with_bias_forward(
|
||||
output = self.triton_kernel_moe_with_bias_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
w1_pcg=self.w13_precision_config,
|
||||
@@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
else:
|
||||
return self.triton_kernel_moe_forward(
|
||||
output = self.triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
else:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
b1=layer.w13_weight_bias,
|
||||
b2=layer.w2_weight_bias,
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
w13_weight_bias=layer.w13_weight_bias,
|
||||
w2_weight_bias=layer.w2_weight_bias,
|
||||
)
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
|
||||
class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
@@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return w, mx_scales
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
|
||||
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
|
||||
|
||||
@@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
if _is_hip:
|
||||
topk_weights = topk_weights.to(
|
||||
torch.float32
|
||||
) # aiter's moe_sorting requires topk_weights to be FP32
|
||||
return fused_moe(
|
||||
output = fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
activation=(
|
||||
ActivationType.Silu
|
||||
if moe_runner_config.activation == "silu"
|
||||
if self.moe_runner_config.activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
doweight_stage1=False,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
|
||||
from aiter.fused_moe import fused_moe
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||
from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.layers.quantization.quark.quark import QuarkConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
|
||||
@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization import QuarkConfig
|
||||
|
||||
|
||||
class QuarkMoEMethod:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
def __init__(self, quant_config: QuarkConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
|
||||
quant_config: QuarkConfig, # type: ignore # noqa E501 # noqa F821
|
||||
module: torch.nn.Module,
|
||||
layer_name: str,
|
||||
) -> "QuarkMoEMethod":
|
||||
@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
# layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
|
||||
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
moe_runner_config = self.moe_runner_config
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
return fused_moe(
|
||||
output = fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
),
|
||||
doweight_stage1=False,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
@@ -24,8 +26,10 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
|
||||
@@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
with_bias: bool = False,
|
||||
**extra_weight_attrs,
|
||||
@@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self.with_bias = with_bias
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
||||
w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size
|
||||
if self.use_triton_kernels:
|
||||
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
|
||||
w13_weight = torch.nn.Parameter(
|
||||
@@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
if self.with_bias:
|
||||
w13_weight_bias = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_bias", w13_weight_bias)
|
||||
@@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
# down_proj (row parallel)
|
||||
w2_weight_n, w2_weight_k = (
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
intermediate_size_per_partition,
|
||||
)
|
||||
if self.use_triton_kernels:
|
||||
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
|
||||
@@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
return
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
return self.forward(
|
||||
x=x,
|
||||
layer=layer,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
dispatch_output=dispatch_output,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
moe_runner_config = self.moe_runner_config
|
||||
|
||||
if self.use_triton_kernels:
|
||||
if self.with_bias:
|
||||
assert self.triton_kernel_moe_with_bias_forward is not None
|
||||
return self.triton_kernel_moe_with_bias_forward(
|
||||
output = self.triton_kernel_moe_with_bias_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
)
|
||||
else:
|
||||
assert self.triton_kernel_moe_forward is not None
|
||||
return self.triton_kernel_moe_forward(
|
||||
output = self.triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
else:
|
||||
if _use_aiter:
|
||||
assert not moe_runner_config.no_combine, "unsupported"
|
||||
@@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_weights = torch.ones_like(
|
||||
topk_weights, dtype=torch.float32
|
||||
) # topk_weights must be FP32 (float32)
|
||||
return fused_moe(
|
||||
output = fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
else:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_experts,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
b1=getattr(layer, "w13_weight_bias", None),
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
b13=getattr(layer, "w13_weight_bias", None),
|
||||
b2=getattr(layer, "w2_weight_bias", None),
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
moe_runner_config = self.moe_runner_config
|
||||
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
), f"activation = {moe_runner_config.activation} is not supported."
|
||||
@@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
output = torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -348,33 +366,39 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
None, # a2_scale
|
||||
True, # is_vnni
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
else:
|
||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||
|
||||
return moe_forward_native(
|
||||
output = moe_forward_native(
|
||||
layer,
|
||||
x,
|
||||
topk_output,
|
||||
moe_runner_config,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
def forward_npu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
return moe_forward_native(
|
||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
output = moe_forward_native(
|
||||
layer,
|
||||
x,
|
||||
topk_output,
|
||||
moe_runner_config,
|
||||
self.moe_runner_config,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
||||
def forward_tpu(self, *args, **kwargs) -> CombineInput:
|
||||
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||
|
||||
forward_native = forward_cpu
|
||||
|
||||
@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
QuantizationConfig,
|
||||
@@ -22,7 +23,10 @@ from sglang.srt.utils import set_weight_attrs
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
@@ -133,7 +137,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: EPMoE,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
@@ -145,7 +149,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size * 2,
|
||||
intermediate_size_per_partition * 2,
|
||||
hidden_size // 2,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
@@ -159,7 +163,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size // 2,
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -173,7 +177,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // self.quant_config.group_size,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
@@ -186,7 +190,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size // self.quant_config.group_size,
|
||||
intermediate_size_per_partition // self.quant_config.group_size,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -220,13 +224,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
self.c_strides1 = torch.full(
|
||||
(num_experts, 3),
|
||||
2 * intermediate_size,
|
||||
2 * intermediate_size_per_partition,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.a_strides2 = torch.full(
|
||||
(num_experts, 3),
|
||||
intermediate_size,
|
||||
intermediate_size_per_partition,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
@@ -282,16 +286,21 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: EPMoE,
|
||||
x: torch.Tensor,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
# TODO(ch-wan): move it out of this class
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
local_topk_ids = topk_ids
|
||||
@@ -328,6 +337,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_input_scale,
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
if moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
if self.moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= self.moe_runner_config.routed_scaling_factor
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -26,8 +27,11 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
|
||||
@@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
@@ -218,7 +222,10 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=fp8_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -226,14 +233,21 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=fp8_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
torch.ones(
|
||||
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
@@ -266,25 +280,26 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: StandardTopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=True,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
@@ -417,7 +421,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
@@ -428,7 +432,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -436,14 +443,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
torch.ones(
|
||||
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
@@ -483,23 +497,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
if use_intel_amx_backend(layer):
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
output = torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -515,20 +536,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_input_scale, # a2_scale
|
||||
True, # is_vnni
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
use_int8_w8a8=True,
|
||||
per_channel_quant=True,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
|
||||
class NPU_W8A8LinearMethodImpl:
|
||||
@@ -900,7 +920,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
@@ -914,21 +934,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
||||
# weight
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
# scale
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
@@ -941,7 +971,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
# offset
|
||||
w13_weight_offset = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_offset", w13_weight_offset)
|
||||
@@ -973,18 +1005,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
|
||||
)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer,
|
||||
x,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
return npu_fused_experts(
|
||||
output = npu_fused_experts(
|
||||
hidden_states=x,
|
||||
w13=layer.w13_weight,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
@@ -994,3 +1033,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
top_k=topk_ids.shape[1],
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
||||
ScheduleBatchDisaggregationDecodeMixin,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.moe import is_tbo_enabled
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
BaseTokenToKVPoolAllocator,
|
||||
SWATokenToKVPoolAllocator,
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
|
||||
from sglang.srt.model_loader.utils import (
|
||||
get_architecture_class_name,
|
||||
get_model_architecture,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
|
||||
|
||||
def get_model(
|
||||
*,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# ruff: noqa: SIM117
|
||||
import collections
|
||||
import concurrent
|
||||
@@ -14,7 +16,17 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
@@ -26,9 +38,7 @@ from tqdm.auto import tqdm
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.connector import (
|
||||
ConnectorType,
|
||||
create_remote_connector,
|
||||
@@ -39,7 +49,6 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_loader.utils import (
|
||||
get_model_architecture,
|
||||
post_load_weights,
|
||||
@@ -70,6 +79,11 @@ from sglang.srt.utils import (
|
||||
set_weight_attrs,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
_is_npu = is_npu()
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from transformers import AutoConfig
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
|
||||
|
||||
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
|
||||
@@ -152,14 +153,32 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
problem_sizes2,
|
||||
)
|
||||
|
||||
topk_output = StandardTopKOutput(
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
router_logits=torch.randn(
|
||||
(batch_size, topk), device=topk_weights.device, dtype=dtype
|
||||
),
|
||||
)
|
||||
|
||||
moe_runner_config = MoeRunnerConfig(
|
||||
num_experts=E,
|
||||
topk=topk,
|
||||
hidden_size=H,
|
||||
shard_intermediate_size=I,
|
||||
dtype=dtype,
|
||||
block_shape=block_shape,
|
||||
activation="silu",
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
# Note: Triton expects non-transposed weights
|
||||
moe_config = MoeRunnerConfig(inplace=False)
|
||||
triton_lambda = lambda: fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
(topk_weights, topk_ids, "dummy"),
|
||||
moe_config,
|
||||
topk_output,
|
||||
moe_runner_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -224,8 +243,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
x,
|
||||
w1, # Original shape
|
||||
w2, # Original shape
|
||||
(topk_weights, topk_ids, "dummy"),
|
||||
moe_config,
|
||||
topk_output,
|
||||
moe_runner_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
|
||||
Reference in New Issue
Block a user