Reorg moe code (#2563)
@@ -60,8 +60,8 @@ def fp8_get_quant_method(self, layer, prefix):
         is_layer_skipped,
     )

-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod

     if isinstance(layer, LinearBase):
@@ -80,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
         GPTQMarlinMoEMethod,
     )

-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

     if isinstance(layer, LinearBase):
         return GPTQMarlinLinearMethod(self)
@@ -96,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
         AWQMoEMethod,
     )

-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

     if isinstance(layer, LinearBase):
         return AWQMarlinLinearMethod(self)
@@ -26,8 +26,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

-from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -98,7 +98,7 @@ class Fp8Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import

-        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
@@ -320,7 +320,7 @@ class Fp8MoEMethod:
     """

     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

         if not hasattr(cls, "_initialized"):
             original_init = cls.__init__
@@ -349,7 +349,7 @@ class Fp8MoEMethod:
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
@@ -566,12 +566,14 @@ class Fp8MoEMethod:
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts

         # Expert selection
-        topk_weights, topk_ids = FusedMoE.select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -580,6 +582,7 @@ class Fp8MoEMethod:
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )

         # Expert fusion with FP8 quantization
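
Taken together, the hunks above amount to two changes for downstream callers: the fused-MoE modules gain an extra `moe` package level, and expert selection moves from the `FusedMoE.select_experts` classmethod to a standalone `select_experts` helper in `sglang.srt.layers.moe.topk`. A minimal before/after sketch of that migration, using only the names that appear in this diff (call arguments elided; the new function also accepts the `correction_bias` keyword added in the last two hunks):

# Old locations (before this change), shown as comments:
#   from sglang.srt.layers.fused_moe_triton import FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported
#   from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts, padding_size
#   topk_weights, topk_ids = FusedMoE.select_experts(...)  # classmethod on FusedMoE

# New locations (after this change): same names, one extra `moe` package level,
# and expert selection is a free function in moe.topk.
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts, padding_size
from sglang.srt.layers.moe.topk import select_experts
# topk_weights, topk_ids = select_experts(...)  # same router arguments, plus correction_bias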