Small refactor of DeepEPMode to clean up the code a bit (#4992)

This commit is contained in:
fzyzcjy
2025-04-03 17:56:44 +08:00
committed by GitHub
parent e8999b13b7
commit 8e10fec9a8
5 changed files with 44 additions and 30 deletions

View File

@@ -38,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
_is_cuda = is_cuda()
@@ -47,7 +47,6 @@ if _is_cuda:
else:
from vllm import _custom_ops as vllm_ops
logger = logging.getLogger(__name__)
_is_hip = is_hip()
@@ -814,7 +813,7 @@ class DeepEPMoE(EPMoE):
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
deepep_mode: str = "auto",
deepep_mode: DeepEPMode = DeepEPMode.auto,
):
super().__init__(
num_experts,
@@ -834,7 +833,7 @@ class DeepEPMoE(EPMoE):
activation,
)
self.deepep_mode = deepep_mode
if self.deepep_mode in ["low_latency", "auto"]:
if self.deepep_mode.enable_low_latency():
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
self.w13_weight_fp8 = (
self.w13_weight,
@@ -858,13 +857,10 @@ class DeepEPMoE(EPMoE):
expected_m: int,
forward_mode: ForwardMode,
):
if self.deepep_mode == "normal" or (
self.deepep_mode == "auto" and not forward_mode.is_decode()
):
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
if resolved_deepep_mode == DeepEPMode.normal:
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
elif self.deepep_mode == "low_latency" or (
self.deepep_mode == "auto" and forward_mode.is_decode()
):
elif resolved_deepep_mode == DeepEPMode.low_latency:
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")