Support custom DeepEP tuning config (#6257)

This commit is contained in:
fzyzcjy
2025-05-18 08:09:42 +08:00
committed by GitHub
parent 26ebb849eb
commit fd08c04821
6 changed files with 79 additions and 5 deletions

View File

@@ -1,8 +1,11 @@
import logging
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.utils import DeepEPMode
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import DeepEPMode, load_json_config
try:
from deep_ep import Buffer
from deep_ep import Buffer, Config
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
@@ -25,6 +28,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
logger = logging.getLogger(__name__)
class DeepEPDispatchMode(IntEnum):
NORMAL = auto()
@@ -32,7 +37,6 @@ class DeepEPDispatchMode(IntEnum):
class DeepEPBuffer:
_buffer = None
_dispatch_mode: Optional[DeepEPDispatchMode] = None
_hidden_size: Optional[int] = None
@@ -60,8 +64,10 @@ class DeepEPBuffer:
if deepep_mode.enable_normal():
hidden_bytes = hidden_size * param_bytes
for config in (
Buffer.get_dispatch_config(group.size()),
Buffer.get_combine_config(group.size()),
_DeepEPConfig.get_instance().normal_dispatch_config
or Buffer.get_dispatch_config(group.size()),
_DeepEPConfig.get_instance().normal_combine_config
or Buffer.get_combine_config(group.size()),
):
num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
@@ -113,6 +119,28 @@ class DeepEPBuffer:
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
class _DeepEPConfig:
_instance = None
def __init__(self):
config_str = global_server_args_dict["deepep_config"]
if config_str:
config_parsed = load_json_config(config_str)
if torch.distributed.get_rank() == 0:
logger.info(f"Use DeepEP Config: {config_parsed}")
self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"])
self.normal_combine_config = Config(**config_parsed["normal_combine"])
else:
self.normal_dispatch_config = None
self.normal_combine_config = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = _DeepEPConfig()
return cls._instance
class _DeepEPDispatcherImplBase:
def __init__(
self,
@@ -295,6 +323,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
config=_DeepEPConfig.get_instance().normal_dispatch_config,
)
return (
@@ -394,6 +423,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish,
previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None,
config=_DeepEPConfig.get_instance().normal_combine_config,
)
return combined_x, event

View File

@@ -77,6 +77,7 @@ global_server_args_dict = {
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"deepep_config": ServerArgs.deepep_config,
"enable_nan_detection": ServerArgs.enable_nan_detection,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"max_micro_batch_size": ServerArgs.max_micro_batch_size,

View File

@@ -165,6 +165,7 @@ class ModelRunner:
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"enable_deepep_moe": server_args.enable_deepep_moe,
"deepep_config": server_args.deepep_config,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"moe_dense_tp_size": server_args.moe_dense_tp_size,
"n_share_experts_fusion": server_args.n_share_experts_fusion,

View File

@@ -169,6 +169,7 @@ class ServerArgs:
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
deepep_config: Optional[str] = None
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None
@@ -1249,6 +1250,12 @@ class ServerArgs:
default="auto",
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
)
parser.add_argument(
"--deepep-config",
type=str,
default=ServerArgs.deepep_config,
help="Tuned DeepEP config suitable for your own cluster.",
)
parser.add_argument(
"--n-share-experts-fusion",

View File

@@ -2102,5 +2102,12 @@ def log_info_on_rank0(logger, msg):
logger.info(msg)
def load_json_config(data: str):
try:
return json.loads(data)
except JSONDecodeError:
return json.loads(Path(data).read_text())
def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))