Support custom DeepEP tuning config (#6257)
This commit is contained in:
@@ -1,8 +1,11 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
from sglang.srt.utils import DeepEPMode
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.utils import DeepEPMode, load_json_config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_ep import Buffer
|
from deep_ep import Buffer, Config
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
@@ -25,6 +28,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DeepEPDispatchMode(IntEnum):
|
class DeepEPDispatchMode(IntEnum):
|
||||||
NORMAL = auto()
|
NORMAL = auto()
|
||||||
@@ -32,7 +37,6 @@ class DeepEPDispatchMode(IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class DeepEPBuffer:
|
class DeepEPBuffer:
|
||||||
|
|
||||||
_buffer = None
|
_buffer = None
|
||||||
_dispatch_mode: Optional[DeepEPDispatchMode] = None
|
_dispatch_mode: Optional[DeepEPDispatchMode] = None
|
||||||
_hidden_size: Optional[int] = None
|
_hidden_size: Optional[int] = None
|
||||||
@@ -60,8 +64,10 @@ class DeepEPBuffer:
|
|||||||
if deepep_mode.enable_normal():
|
if deepep_mode.enable_normal():
|
||||||
hidden_bytes = hidden_size * param_bytes
|
hidden_bytes = hidden_size * param_bytes
|
||||||
for config in (
|
for config in (
|
||||||
Buffer.get_dispatch_config(group.size()),
|
_DeepEPConfig.get_instance().normal_dispatch_config
|
||||||
Buffer.get_combine_config(group.size()),
|
or Buffer.get_dispatch_config(group.size()),
|
||||||
|
_DeepEPConfig.get_instance().normal_combine_config
|
||||||
|
or Buffer.get_combine_config(group.size()),
|
||||||
):
|
):
|
||||||
num_nvl_bytes = max(
|
num_nvl_bytes = max(
|
||||||
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
|
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
|
||||||
@@ -113,6 +119,28 @@ class DeepEPBuffer:
|
|||||||
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
||||||
|
|
||||||
|
|
||||||
|
class _DeepEPConfig:
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
config_str = global_server_args_dict["deepep_config"]
|
||||||
|
if config_str:
|
||||||
|
config_parsed = load_json_config(config_str)
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
logger.info(f"Use DeepEP Config: {config_parsed}")
|
||||||
|
self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"])
|
||||||
|
self.normal_combine_config = Config(**config_parsed["normal_combine"])
|
||||||
|
else:
|
||||||
|
self.normal_dispatch_config = None
|
||||||
|
self.normal_combine_config = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = _DeepEPConfig()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
|
||||||
class _DeepEPDispatcherImplBase:
|
class _DeepEPDispatcherImplBase:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -295,6 +323,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
async_finish=self.async_finish,
|
async_finish=self.async_finish,
|
||||||
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||||
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
||||||
|
config=_DeepEPConfig.get_instance().normal_dispatch_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -394,6 +423,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
async_finish=self.async_finish,
|
async_finish=self.async_finish,
|
||||||
previous_event=previous_event,
|
previous_event=previous_event,
|
||||||
allocate_on_comm_stream=previous_event is not None,
|
allocate_on_comm_stream=previous_event is not None,
|
||||||
|
config=_DeepEPConfig.get_instance().normal_combine_config,
|
||||||
)
|
)
|
||||||
return combined_x, event
|
return combined_x, event
|
||||||
|
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ global_server_args_dict = {
|
|||||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||||
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
|
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
|
||||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||||
|
"deepep_config": ServerArgs.deepep_config,
|
||||||
"enable_nan_detection": ServerArgs.enable_nan_detection,
|
"enable_nan_detection": ServerArgs.enable_nan_detection,
|
||||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||||
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
|
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
|
||||||
|
|||||||
@@ -165,6 +165,7 @@ class ModelRunner:
|
|||||||
"enable_dp_attention": server_args.enable_dp_attention,
|
"enable_dp_attention": server_args.enable_dp_attention,
|
||||||
"enable_ep_moe": server_args.enable_ep_moe,
|
"enable_ep_moe": server_args.enable_ep_moe,
|
||||||
"enable_deepep_moe": server_args.enable_deepep_moe,
|
"enable_deepep_moe": server_args.enable_deepep_moe,
|
||||||
|
"deepep_config": server_args.deepep_config,
|
||||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||||
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
||||||
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
||||||
|
|||||||
@@ -169,6 +169,7 @@ class ServerArgs:
|
|||||||
enable_ep_moe: bool = False
|
enable_ep_moe: bool = False
|
||||||
enable_deepep_moe: bool = False
|
enable_deepep_moe: bool = False
|
||||||
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||||
|
deepep_config: Optional[str] = None
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
cuda_graph_max_bs: Optional[int] = None
|
cuda_graph_max_bs: Optional[int] = None
|
||||||
@@ -1249,6 +1250,12 @@ class ServerArgs:
|
|||||||
default="auto",
|
default="auto",
|
||||||
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
|
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--deepep-config",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.deepep_config,
|
||||||
|
help="Tuned DeepEP config suitable for your own cluster.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n-share-experts-fusion",
|
"--n-share-experts-fusion",
|
||||||
|
|||||||
@@ -2102,5 +2102,12 @@ def log_info_on_rank0(logger, msg):
|
|||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def load_json_config(data: str):
|
||||||
|
try:
|
||||||
|
return json.loads(data)
|
||||||
|
except JSONDecodeError:
|
||||||
|
return json.loads(Path(data).read_text())
|
||||||
|
|
||||||
|
|
||||||
def dispose_tensor(x: torch.Tensor):
|
def dispose_tensor(x: torch.Tensor):
|
||||||
x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
|
x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
@@ -64,8 +66,34 @@ class TestDPAttn(unittest.TestCase):
|
|||||||
"2",
|
"2",
|
||||||
"--enable-dp-attention",
|
"--enable-dp-attention",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
|
"--deepep-mode",
|
||||||
|
"normal",
|
||||||
"--disable-cuda-graph",
|
"--disable-cuda-graph",
|
||||||
|
# Test custom config
|
||||||
|
"--deepep-config",
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"normal_dispatch": {
|
||||||
|
"num_sms": 20,
|
||||||
|
"num_max_nvl_chunked_send_tokens": 16,
|
||||||
|
"num_max_nvl_chunked_recv_tokens": 256,
|
||||||
|
"num_max_rdma_chunked_send_tokens": 6,
|
||||||
|
"num_max_rdma_chunked_recv_tokens": 128,
|
||||||
|
},
|
||||||
|
"normal_combine": {
|
||||||
|
"num_sms": 20,
|
||||||
|
"num_max_nvl_chunked_send_tokens": 6,
|
||||||
|
"num_max_nvl_chunked_recv_tokens": 256,
|
||||||
|
"num_max_rdma_chunked_send_tokens": 6,
|
||||||
|
"num_max_rdma_chunked_recv_tokens": 128,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
],
|
],
|
||||||
|
env={
|
||||||
|
"SGL_ENABLE_JIT_DEEPGEMM": "0",
|
||||||
|
**os.environ,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user