Add DeepSeek V3/R1 shared experts fusion (#4918)
This commit is contained in:
@@ -51,7 +51,6 @@ except ImportError:
|
||||
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||
@@ -203,6 +202,8 @@ def get_linear_quant_method(
|
||||
|
||||
|
||||
def gptq_get_quant_method(self, layer, prefix):
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
layer.scheme = scheme
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return CompressedTensorsMoEMethod.get_moe_method(self)
|
||||
return None
|
||||
|
||||
@@ -4,18 +4,19 @@
|
||||
import enum
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Optional
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
all_close_1d,
|
||||
@@ -55,7 +56,13 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
class CompressedTensorsMoEMethod:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if cls is CompressedTensorsMoEMethod:
|
||||
return super().__new__(cls)
|
||||
return super().__new__(cls)
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
scoring_func: str = "softmax",
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError(
|
||||
|
||||
Reference in New Issue
Block a user