Add DeepSeek V3/R1 shared experts fusion (#4918)

This commit is contained in:
Xiaoyu Zhang
2025-04-04 16:59:29 +08:00
committed by GitHub
parent 6ff9c6a5e7
commit 924ca7c92c
14 changed files with 536 additions and 36 deletions

View File

@@ -51,7 +51,6 @@ except ImportError:
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -203,6 +202,8 @@ def get_linear_quant_method(
def gptq_get_quant_method(self, layer, prefix):
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)

View File

@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self)
return None

View File

@@ -4,18 +4,19 @@
import enum
import logging
from enum import Enum
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.moe.topk import select_experts
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
@@ -55,7 +56,13 @@ __all__ = [
]
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsMoEMethod:
def __new__(cls, *args, **kwargs):
    """Allocate a new instance of ``cls`` through the parent ``__new__``.

    Positional and keyword arguments are accepted so that subclasses may
    define their own ``__init__`` signatures; they are not forwarded to
    the parent allocator.
    """
    # Imported at call time rather than module load time; the name itself is
    # not used here. NOTE(review): presumably deferred to avoid a circular
    # import with the fused-MoE package -- confirm before removing.
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase  # noqa: F401

    # The base class and every subclass are allocated the same way, so no
    # dispatch on ``cls`` is required.
    return super().__new__(cls)
@staticmethod
def get_moe_method(
@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
params_dtype = torch.float8_e4m3fn
@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
if not VLLM_AVAILABLE:
raise ImportError(