Small refactor DeepEPMode to clean up code a bit (#4992)
This commit is contained in:
@@ -38,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
|
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
@@ -47,7 +47,6 @@ if _is_cuda:
|
|||||||
else:
|
else:
|
||||||
from vllm import _custom_ops as vllm_ops
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
@@ -814,7 +813,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
deepep_mode: str = "auto",
|
deepep_mode: DeepEPMode = DeepEPMode.auto,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -834,7 +833,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
activation,
|
activation,
|
||||||
)
|
)
|
||||||
self.deepep_mode = deepep_mode
|
self.deepep_mode = deepep_mode
|
||||||
if self.deepep_mode in ["low_latency", "auto"]:
|
if self.deepep_mode.enable_low_latency():
|
||||||
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||||
self.w13_weight_fp8 = (
|
self.w13_weight_fp8 = (
|
||||||
self.w13_weight,
|
self.w13_weight,
|
||||||
@@ -858,13 +857,10 @@ class DeepEPMoE(EPMoE):
|
|||||||
expected_m: int,
|
expected_m: int,
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
):
|
):
|
||||||
if self.deepep_mode == "normal" or (
|
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||||
self.deepep_mode == "auto" and not forward_mode.is_decode()
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
):
|
|
||||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||||
elif self.deepep_mode == "low_latency" or (
|
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||||
self.deepep_mode == "auto" and forward_mode.is_decode()
|
|
||||||
):
|
|
||||||
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from sglang.srt.utils import DeepEPMode
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_ep import Buffer
|
from deep_ep import Buffer
|
||||||
|
|
||||||
@@ -98,7 +100,7 @@ class DeepEPDispatcher:
|
|||||||
num_local_experts: int = None,
|
num_local_experts: int = None,
|
||||||
hidden_size: int = None,
|
hidden_size: int = None,
|
||||||
params_dtype: torch.dtype = None,
|
params_dtype: torch.dtype = None,
|
||||||
deepep_mode: str = "auto",
|
deepep_mode: DeepEPMode = DeepEPMode.auto,
|
||||||
async_finish: bool = False,
|
async_finish: bool = False,
|
||||||
return_recv_hook: bool = False,
|
return_recv_hook: bool = False,
|
||||||
):
|
):
|
||||||
@@ -120,13 +122,13 @@ class DeepEPDispatcher:
|
|||||||
self.deepep_mode = deepep_mode
|
self.deepep_mode = deepep_mode
|
||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
if self.deepep_mode in ["normal", "auto"]: # for normal / auto mode
|
if self.deepep_mode.enable_normal():
|
||||||
self.buffer_normal = get_buffer_normal(
|
self.buffer_normal = get_buffer_normal(
|
||||||
self.group, self.hidden_size * self.params_bytes
|
self.group, self.hidden_size * self.params_bytes
|
||||||
)
|
)
|
||||||
self.async_finish = async_finish
|
self.async_finish = async_finish
|
||||||
self.src2dst = None
|
self.src2dst = None
|
||||||
if self.deepep_mode in ["low_latency", "auto"]: # for low_latency / auto mode
|
if self.deepep_mode.enable_low_latency():
|
||||||
"""
|
"""
|
||||||
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
|
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
|
||||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||||
@@ -196,9 +198,8 @@ class DeepEPDispatcher:
|
|||||||
)
|
)
|
||||||
expected_m = 0
|
expected_m = 0
|
||||||
|
|
||||||
if self.deepep_mode == "normal" or (
|
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||||
self.deepep_mode == "auto" and not forward_mode.is_decode()
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
):
|
|
||||||
(
|
(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
@@ -210,9 +211,7 @@ class DeepEPDispatcher:
|
|||||||
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
|
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
|
||||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||||
)
|
)
|
||||||
elif self.deepep_mode == "low_latency" or (
|
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||||
self.deepep_mode == "auto" and forward_mode.is_decode()
|
|
||||||
):
|
|
||||||
expected_m = (
|
expected_m = (
|
||||||
hidden_states.shape[0]
|
hidden_states.shape[0]
|
||||||
* self.buffer_low_latency.group_size
|
* self.buffer_low_latency.group_size
|
||||||
@@ -354,9 +353,8 @@ class DeepEPDispatcher:
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.deepep_mode == "normal" or (
|
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||||
self.deepep_mode == "auto" and not forward_mode.is_decode()
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
):
|
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
@@ -384,9 +382,7 @@ class DeepEPDispatcher:
|
|||||||
output,
|
output,
|
||||||
)
|
)
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
elif self.deepep_mode == "low_latency" or (
|
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||||
self.deepep_mode == "auto" and forward_mode.is_decode()
|
|
||||||
):
|
|
||||||
hidden_states, event, hook = self.combine_low_latency(
|
hidden_states, event, hook = self.combine_low_latency(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix, is_cuda, is_hip
|
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -215,7 +215,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
topk_group=config.topk_group,
|
topk_group=config.topk_group,
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
deepep_mode=global_server_args_dict["deepep_mode"],
|
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.n_shared_experts is not None:
|
if config.n_shared_experts is not None:
|
||||||
@@ -264,7 +264,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
params_dtype=config.torch_dtype,
|
params_dtype=config.torch_dtype,
|
||||||
deepep_mode=global_server_args_dict["deepep_mode"],
|
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
||||||
async_finish=True, # TODO
|
async_finish=True, # TODO
|
||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import List, Optional
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
@@ -161,7 +161,7 @@ class ServerArgs:
|
|||||||
enable_dp_attention: bool = False
|
enable_dp_attention: bool = False
|
||||||
enable_ep_moe: bool = False
|
enable_ep_moe: bool = False
|
||||||
enable_deepep_moe: bool = False
|
enable_deepep_moe: bool = False
|
||||||
deepep_mode: Optional[str] = "auto"
|
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
cuda_graph_max_bs: Optional[int] = None
|
cuda_graph_max_bs: Optional[int] = None
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from enum import Enum
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
@@ -1838,3 +1839,24 @@ def flatten_nested_list(nested_list):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
return [nested_list]
|
return [nested_list]
|
||||||
|
|
||||||
|
|
||||||
|
class DeepEPMode(Enum):
    """Selectable DeepEP modes: ``normal``, ``low_latency`` or ``auto``.

    Members can be looked up by name (``DeepEPMode["auto"]``) or by value
    (``DeepEPMode("auto")``); each member's name equals its string value.
    """

    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_normal(self):
        """Return True when this mode is ``normal`` or ``auto``."""
        return self is DeepEPMode.normal or self is DeepEPMode.auto

    def enable_low_latency(self):
        """Return True when this mode is ``low_latency`` or ``auto``."""
        return self is DeepEPMode.low_latency or self is DeepEPMode.auto

    def resolve(self, forward_mode):
        """Collapse ``auto`` into a concrete mode for one forward pass.

        ``normal`` and ``low_latency`` are returned unchanged and
        ``forward_mode`` is not consulted. For ``auto`` the result is
        ``low_latency`` when ``forward_mode.is_decode()`` is truthy and
        ``normal`` otherwise.
        """
        # Explicit modes never depend on the forward mode.
        if self is not DeepEPMode.auto:
            return self

        return (
            DeepEPMode.low_latency
            if forward_mode.is_decode()
            else DeepEPMode.normal
        )
|
||||||
|
|||||||
Reference in New Issue
Block a user