Support custom DeepEP tuning config (#6257)
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.utils import DeepEPMode
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import DeepEPMode, load_json_config
|
||||
|
||||
try:
|
||||
from deep_ep import Buffer
|
||||
from deep_ep import Buffer, Config
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
@@ -25,6 +28,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeepEPDispatchMode(IntEnum):
|
||||
NORMAL = auto()
|
||||
@@ -32,7 +37,6 @@ class DeepEPDispatchMode(IntEnum):
|
||||
|
||||
|
||||
class DeepEPBuffer:
|
||||
|
||||
_buffer = None
|
||||
_dispatch_mode: Optional[DeepEPDispatchMode] = None
|
||||
_hidden_size: Optional[int] = None
|
||||
@@ -60,8 +64,10 @@ class DeepEPBuffer:
|
||||
if deepep_mode.enable_normal():
|
||||
hidden_bytes = hidden_size * param_bytes
|
||||
for config in (
|
||||
Buffer.get_dispatch_config(group.size()),
|
||||
Buffer.get_combine_config(group.size()),
|
||||
_DeepEPConfig.get_instance().normal_dispatch_config
|
||||
or Buffer.get_dispatch_config(group.size()),
|
||||
_DeepEPConfig.get_instance().normal_combine_config
|
||||
or Buffer.get_combine_config(group.size()),
|
||||
):
|
||||
num_nvl_bytes = max(
|
||||
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
|
||||
@@ -113,6 +119,28 @@ class DeepEPBuffer:
|
||||
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
||||
|
||||
|
||||
class _DeepEPConfig:
    """Process-wide singleton holding optional user-tuned DeepEP configs.

    The configs are read from the ``deepep_config`` server argument, which
    may be either an inline JSON document or a path to a JSON file (both
    handled by ``load_json_config``).  When no config is supplied, the
    attributes stay ``None`` and DeepEP falls back to its built-in defaults.
    """

    # Lazily created singleton instance; see get_instance().
    _instance = None

    def __init__(self):
        raw = global_server_args_dict["deepep_config"]
        if not raw:
            # No user-supplied tuning: leave both configs unset so callers
            # fall back to DeepEP's defaults.
            self.normal_dispatch_config = None
            self.normal_combine_config = None
            return

        parsed = load_json_config(raw)
        # Log only on rank 0 to avoid one line per process.
        if torch.distributed.get_rank() == 0:
            logger.info(f"Use DeepEP Config: {parsed}")
        # Expected schema: {"normal_dispatch": {...}, "normal_combine": {...}}
        # whose values are keyword arguments for deep_ep.Config.
        self.normal_dispatch_config = Config(**parsed["normal_dispatch"])
        self.normal_combine_config = Config(**parsed["normal_combine"])

    @classmethod
    def get_instance(cls):
        """Return the singleton, creating it on first use."""
        if cls._instance is None:
            cls._instance = _DeepEPConfig()
        return cls._instance
|
||||
|
||||
|
||||
class _DeepEPDispatcherImplBase:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -295,6 +323,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
async_finish=self.async_finish,
|
||||
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
||||
config=_DeepEPConfig.get_instance().normal_dispatch_config,
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -394,6 +423,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
async_finish=self.async_finish,
|
||||
previous_event=previous_event,
|
||||
allocate_on_comm_stream=previous_event is not None,
|
||||
config=_DeepEPConfig.get_instance().normal_combine_config,
|
||||
)
|
||||
return combined_x, event
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ global_server_args_dict = {
|
||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
|
||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||
"deepep_config": ServerArgs.deepep_config,
|
||||
"enable_nan_detection": ServerArgs.enable_nan_detection,
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
|
||||
|
||||
@@ -165,6 +165,7 @@ class ModelRunner:
|
||||
"enable_dp_attention": server_args.enable_dp_attention,
|
||||
"enable_ep_moe": server_args.enable_ep_moe,
|
||||
"enable_deepep_moe": server_args.enable_deepep_moe,
|
||||
"deepep_config": server_args.deepep_config,
|
||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
||||
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
||||
|
||||
@@ -169,6 +169,7 @@ class ServerArgs:
|
||||
enable_ep_moe: bool = False
|
||||
enable_deepep_moe: bool = False
|
||||
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||
deepep_config: Optional[str] = None
|
||||
enable_torch_compile: bool = False
|
||||
torch_compile_max_bs: int = 32
|
||||
cuda_graph_max_bs: Optional[int] = None
|
||||
@@ -1249,6 +1250,12 @@ class ServerArgs:
|
||||
default="auto",
|
||||
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--deepep-config",
|
||||
type=str,
|
||||
default=ServerArgs.deepep_config,
|
||||
help="Tuned DeepEP config suitable for your own cluster.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n-share-experts-fusion",
|
||||
|
||||
@@ -2102,5 +2102,12 @@ def log_info_on_rank0(logger, msg):
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
def load_json_config(data: str):
    """Deserialize *data*, which is either a JSON document or a path to one.

    Args:
        data: An inline JSON string, or a filesystem path to a JSON file.

    Returns:
        The parsed JSON value.
    """
    try:
        parsed = json.loads(data)
    except JSONDecodeError:
        # Not valid inline JSON -- treat the string as a file path instead.
        parsed = json.loads(Path(data).read_text())
    return parsed
|
||||
|
||||
|
||||
def dispose_tensor(x: torch.Tensor):
    """Release *x*'s storage immediately by repointing it at an empty buffer.

    The tensor object stays alive (now with zero elements, same dtype and
    device), so other references to it remain valid while the original
    memory becomes reclaimable without waiting for garbage collection.
    """
    placeholder = torch.empty((0,), dtype=x.dtype, device=x.device)
    x.set_(placeholder)
|
||||
|
||||
Reference in New Issue
Block a user