fix: resolve fp8 moe issue (#2387)
This commit is contained in:
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
|||||||
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||||
|
|
||||||
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||||
"aqlm": AQLMConfig,
|
"aqlm": AQLMConfig,
|
||||||
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|||||||
return QUANTIZATION_METHODS[quantization]
|
return QUANTIZATION_METHODS[quantization]
|
||||||
|
|
||||||
|
|
||||||
def fp8_moe_apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Enhanced apply method for FP8 MoE."""
|
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
|
|
||||||
|
|
||||||
# Expert selection
|
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
top_k=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Expert fusion with FP8 quantization
|
|
||||||
return fused_experts(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
inplace=True,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def fp8_get_quant_method(self, layer, prefix):
|
def fp8_get_quant_method(self, layer, prefix):
|
||||||
"""Enhanced get_quant_method for FP8 config."""
|
"""Enhanced get_quant_method for FP8 config."""
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
@@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix):
|
|||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped(prefix, self.ignored_layers):
|
if is_layer_skipped(prefix, self.ignored_layers):
|
||||||
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
|
|||||||
|
|
||||||
def apply_monkey_patches():
|
def apply_monkey_patches():
|
||||||
"""Apply all monkey patches in one place."""
|
"""Apply all monkey patches in one place."""
|
||||||
setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
|
|
||||||
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
||||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||||
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
||||||
|
|||||||
@@ -24,11 +24,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton import (
|
|
||||||
FusedMoE,
|
|
||||||
FusedMoEMethodBase,
|
|
||||||
FusedMoeWeightScaleSupported,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
@@ -100,6 +95,8 @@ class Fp8Config(QuantizationConfig):
|
|||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
from vllm.attention.layer import Attention # Avoid circular import
|
||||||
|
|
||||||
|
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped(prefix, self.ignored_layers):
|
if is_layer_skipped(prefix, self.ignored_layers):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
@@ -306,7 +303,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
class Fp8MoEMethod:
|
||||||
"""MoE method for FP8.
|
"""MoE method for FP8.
|
||||||
Supports loading FP8 checkpoints with static weight scale and
|
Supports loading FP8 checkpoints with static weight scale and
|
||||||
dynamic/static activation scale.
|
dynamic/static activation scale.
|
||||||
@@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
quant_config: The quantization config.
|
quant_config: The quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: Fp8Config):
|
def __new__(cls, *args, **kwargs):
|
||||||
|
from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
|
||||||
|
|
||||||
|
if not hasattr(cls, "_initialized"):
|
||||||
|
original_init = cls.__init__
|
||||||
|
new_cls = type(
|
||||||
|
cls.__name__,
|
||||||
|
(FusedMoEMethodBase,),
|
||||||
|
{
|
||||||
|
"__init__": original_init,
|
||||||
|
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||||
|
obj.__init__(*args, **kwargs)
|
||||||
|
return obj
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def __init__(self, quant_config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
@@ -331,6 +346,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
|
from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = torch.float8_e4m3fn
|
params_dtype = torch.float8_e4m3fn
|
||||||
@@ -521,8 +537,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
|
|||||||
Reference in New Issue
Block a user