fix: resolve fp8 moe issue (#2387)

Author: Yineng Zhang
Date: 2024-12-07 19:28:53 +08:00
Committed by: GitHub
Parent: c36736c841
Commit: d332aa3b0c

2 changed files with 27 additions and 56 deletions


@@ -24,11 +24,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -100,6 +95,8 @@ class Fp8Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
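
Note: the module-level import of FusedMoE / FusedMoEMethodBase / FusedMoeWeightScaleSupported is dropped and requested lazily inside get_quant_method instead, presumably because fused_moe_triton itself imports the quantization machinery and a module-level import would form a cycle at startup. The stand-alone sketch below (standard library only; module names such as quant_demo and moe_demo are invented for illustration and are not sglang's real layout) shows why a function-local import resolves such a cycle while a module-level one fails:

import sys
import tempfile
import textwrap
from pathlib import Path

tmp = Path(tempfile.mkdtemp())

# quant_demo only needs the MoE class inside a function, so it defers the import.
(tmp / "quant_demo.py").write_text(textwrap.dedent("""
    def get_quant_method(layer):
        # Deferred import: runs on first call, after both modules finished loading.
        from moe_demo import FusedMoE
        return "fp8-moe" if isinstance(layer, FusedMoE) else "other"
"""))

# moe_demo imports quant_demo at module level, which is what starts the cycle.
(tmp / "moe_demo.py").write_text(textwrap.dedent("""
    from quant_demo import get_quant_method

    class FusedMoE:
        pass
"""))

sys.path.insert(0, str(tmp))
import moe_demo  # loads quant_demo as a side effect without raising

print(moe_demo.get_quant_method(moe_demo.FusedMoE()))  # -> fp8-moe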
@@ -306,7 +303,7 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
 
-class Fp8MoEMethod(FusedMoEMethodBase):
+class Fp8MoEMethod:
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
         self.quant_config = quant_config
 
     def create_weights(
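
Note: Fp8MoEMethod no longer names FusedMoEMethodBase in its class statement; instead __new__ imports the base lazily and builds an equivalent subclass with type() when an instance is constructed, so the base only has to be importable at instantiation time, not at module-import time. Below is a minimal, self-contained sketch of that late-binding trick; BaseMethod stands in for FusedMoEMethodBase, and every name is illustrative rather than sglang's API:

class BaseMethod:
    def process(self):
        return "base behaviour"


class LateBoundMethod:
    """Declared without its real base; the base is attached on first construction."""

    def __new__(cls, *args, **kwargs):
        # Mirrors the hunk above. Nothing ever sets "_initialized", so this
        # branch rebuilds the subclass on every construction call.
        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (BaseMethod,),  # the lazily bound base class
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config


m = LateBoundMethod({"fp8": True})
print(isinstance(m, BaseMethod))         # True: the returned object subclasses BaseMethod
print(m.process())                       # "base behaviour", inherited from the lazy base
print(type(m).__name__, m.quant_config)  # LateBoundMethod {'fp8': True}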
@@ -331,6 +346,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -521,8 +537,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
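
Note: the expert-selection and fused-experts entry points now come from sglang's own fused_moe_triton package instead of vllm, again via method-local imports. A small sanity-check sketch (it assumes an environment where sglang and its Triton dependencies are installed; the module path is taken verbatim from the hunk above) is to confirm at runtime which module actually provides fused_experts:

import importlib

mod = importlib.import_module("sglang.srt.layers.fused_moe_triton.fused_moe")
print(mod.__file__)                   # should resolve inside sglang, not vllm
print(hasattr(mod, "fused_experts"))  # True: the symbol imported in apply() above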