Reorg moe code (#2563)

This commit is contained in:
Ke Bao
2024-12-24 01:10:22 +08:00
committed by GitHub
parent 23e5e50fd5
commit e835a50021
88 changed files with 338 additions and 344 deletions
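Both files shown below follow the same migration: the fused MoE Triton code moves from sglang.srt.layers.fused_moe_triton to the new sglang.srt.layers.moe.fused_moe_triton package, and expert selection is split out of FusedMoE into sglang.srt.layers.moe.topk. A minimal before/after import sketch for downstream code (module paths are verbatim from the hunks below; grouping the symbols onto shared lines is illustrative):

    # Before #2563:
    #   from sglang.srt.layers.fused_moe_triton import FusedMoE, FusedMoEMethodBase
    #   from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts, padding_size

    # After #2563:
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE, FusedMoEMethodBase
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts, padding_size
    from sglang.srt.layers.moe.topk import select_experts  # replaces FusedMoE.select_experts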

View File

@@ -60,8 +60,8 @@ def fp8_get_quant_method(self, layer, prefix):
         is_layer_skipped,
     )
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod

     if isinstance(layer, LinearBase):
@@ -80,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
         GPTQMarlinMoEMethod,
     )
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

     if isinstance(layer, LinearBase):
         return GPTQMarlinLinearMethod(self)
@@ -96,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
         AWQMoEMethod,
     )
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

     if isinstance(layer, LinearBase):
         return AWQMarlinLinearMethod(self)
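
Each *_get_quant_method helper above dispatches on the layer type, so only its local FusedMoE import needed retargeting. A sketch of the AWQ variant under the new path; the vllm import locations, the FusedMoE branch, and the trailing return None are assumptions, since the hunk only hints at them:

    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinLinearMethod,
        AWQMoEMethod,
    )

    def awq_get_quant_method(self, layer, prefix):
        # Imported at call time: the MoE layer module depends on the
        # quantization package, so a module-level import would be circular.
        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

        if isinstance(layer, LinearBase):
            return AWQMarlinLinearMethod(self)
        if isinstance(layer, FusedMoE):
            return AWQMoEMethod(self)  # assumed branch, mirroring the FP8/GPTQ helpers
        return None                    # assumed fallback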

View File

@@ -26,8 +26,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
-from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -98,7 +98,7 @@ class Fp8Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
-        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
@@ -320,7 +320,7 @@ class Fp8MoEMethod:
"""
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
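
The __new__ hook in the hunk above is a lazy class-setup pattern: FusedMoEMethodBase would be a circular import at class-definition time, so it is pulled in on first instantiation and the class is rebuilt around it. A generic, runnable sketch of the idea; only the _initialized guard and original_init come from the hunk, everything else is illustrative:

    class Base:
        # Stand-in for the lazily imported FusedMoEMethodBase.
        def apply(self):
            return "base apply"

    class MoEMethodSketch:
        def __new__(cls, *args, **kwargs):
            if not hasattr(cls, "_initialized"):
                # First path: rebuild the class with the lazy base attached.
                # The rebuilt class carries _initialized, so its instances
                # skip this branch.
                original_init = cls.__init__
                new_cls = type(
                    cls.__name__,
                    (Base,),
                    {"__init__": original_init, "_initialized": True},
                )
                obj = super(new_cls, new_cls).__new__(new_cls)
                obj.__init__(*args, **kwargs)
                return obj
            return super().__new__(cls)

    m = MoEMethodSketch()
    assert m.apply() == "base apply"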
@@ -349,7 +349,7 @@ class Fp8MoEMethod:
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -566,12 +566,14 @@ class Fp8MoEMethod:
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts

         # Expert selection
-        topk_weights, topk_ids = FusedMoE.select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -580,6 +582,7 @@ class Fp8MoEMethod:
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )

         # Expert fusion with FP8 quantization
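
Taken together, the FP8 MoE apply path now looks like the sketch below: routing goes through the free function select_experts, forwarding the new correction_bias, and the result feeds fused_experts. All names shown in the hunks are verbatim; top_k and renormalize are assumptions, since the diff elides the lines between the two hunks:

    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
    from sglang.srt.layers.moe.topk import select_experts

    # Expert selection (previously FusedMoE.select_experts).
    topk_weights, topk_ids = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,              # assumed: elided between the two hunks
        renormalize=renormalize,  # assumed: elided between the two hunks
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        correction_bias=correction_bias,  # new pass-through added by this commit
    )
    # Expert fusion with FP8 quantization then consumes
    # topk_weights/topk_ids via fused_experts(...).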