From 7f19e083c1d22ddd4ed7784c24326fac1c67aeea Mon Sep 17 00:00:00 2001 From: tarinkk <129432511+tarinkk@users.noreply.github.com> Date: Thu, 27 Mar 2025 20:09:35 -0400 Subject: [PATCH] Support (1 <= dp < tp) in the dp attention in DeepEP (#4770) Co-authored-by: Cheng Wan --- docs/backend/server_arguments.md | 4 +- .../device_communicators/custom_all_reduce.py | 2 +- .../sglang/srt/distributed/parallel_state.py | 23 ++- python/sglang/srt/layers/dp_attention.py | 13 +- python/sglang/srt/managers/scheduler.py | 2 +- .../srt/model_executor/cuda_graph_runner.py | 13 +- .../sglang/srt/model_executor/model_runner.py | 3 - python/sglang/srt/models/deepseek_v2.py | 171 +++++++++++++++--- python/sglang/srt/server_args.py | 17 +- test/srt/test_moe_deepep.py | 37 +++- 10 files changed, 238 insertions(+), 47 deletions(-) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 2027c082f..3d2aae8f2 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -90,7 +90,7 @@ Please consult the documentation below to learn more about the parameters you ma ### Expert parallelism * `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models. * `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`. -* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP. Currently DeepEP is bind to DP Attention. Please set `--enable-dp-attention --enable-deepep-moe`, perfer `tp_size=dp_size=ep_size`. +* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP. 
## Memory and scheduling @@ -184,7 +184,7 @@ Please consult the documentation below to learn more about the parameters you ma *Note: Some of these options are still in experimental stage.* * `enable_mixed_chunk`: Enables mixing prefill and decode, see [this discussion](https://github.com/sgl-project/sglang/discussions/1163). -* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. Note that you need to choose `dp_size = tp_size` for this. +* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. * `enable_torch_compile`: Torch compile the model. Note that compiling a model takes a long time but have a great performance boost. The compiled model can also be [cached for future use](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile). * `torch_compile_max_bs`: The maximum batch size when using `torch_compile`. * `cuda_graph_max_bs`: Adjust the maximum batchsize when using cuda graph. By default this is chosen for you based on GPU specifics. 
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index 2d5f9ada4..813ae0122 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -5,7 +5,7 @@ import logging import os from contextlib import contextmanager from functools import wraps -from typing import Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar, Union import torch import torch.distributed as dist diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index ced337205..d4000b866 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -439,6 +439,15 @@ class GroupCoordinator: else: torch.distributed.all_reduce(input_, group=self.device_group) + def reduce_scatter( + self, + output: torch.Tensor, + input_list: List[torch.Tensor], + ) -> None: + # TODO(ch-wan): support other backends + torch.distributed.reduce_scatter(output, input_list, group=self.device_group) + return output + def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: @@ -456,11 +465,23 @@ class GroupCoordinator: output, input, group_name=self.unique_name ) - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + def all_gather( + self, + input_: torch.Tensor, + dim: int = -1, + tensor_list: List[torch.Tensor] = None, + ) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. 
if world_size == 1: return input_ + + if tensor_list is not None: + # TODO(ch-wan): support other backends + return torch.distributed.all_gather( + tensor_list, input_, group=self.device_group + ) + assert ( -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index c36b9706e..bf6064119 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -3,7 +3,7 @@ from __future__ import annotations import functools import logging from contextlib import contextmanager -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, List import torch import triton @@ -249,3 +249,14 @@ def dp_scatter( memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True ) + + +def tp_reduce_scatter( + output: torch.Tensor, + input_list: List[torch.Tensor], +): + return get_attention_tp_group().reduce_scatter(output, input_list) + + +def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): + return get_attention_tp_group().all_gather(input_, tensor_list=output_list) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7b5dc4520..70ce381be 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1186,7 +1186,7 @@ class Scheduler( ret = None # Handle DP attention - if self.server_args.enable_dp_attention: + if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm: ret, _ = self.prepare_dp_attn_batch(ret) return ret diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 449113c70..f5ac35d40 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -174,6 +174,7 @@ class 
CudaGraphRunner: self.disable_padding = model_runner.server_args.disable_cuda_graph_padding self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder self.enable_dp_attention = model_runner.server_args.enable_dp_attention + self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm self.speculative_algorithm = model_runner.server_args.speculative_algorithm self.tp_size = model_runner.server_args.tp_size self.dp_size = model_runner.server_args.dp_size @@ -245,8 +246,8 @@ class CudaGraphRunner: ) else: self.encoder_lens = None - - if self.enable_dp_attention: + if self.enable_dp_attention or self.enable_sp_layernorm: + # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer self.gathered_buffer = torch.zeros( ( self.max_bs * self.dp_size * self.num_tokens_per_bs, @@ -288,7 +289,7 @@ class CudaGraphRunner: self.model_runner.token_to_kv_pool.capture_mode = False def can_run(self, forward_batch: ForwardBatch): - if self.enable_dp_attention: + if self.enable_dp_attention or self.enable_sp_layernorm: total_global_tokens = sum(forward_batch.global_num_tokens_cpu) is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( @@ -369,7 +370,7 @@ class CudaGraphRunner: encoder_lens = None mrope_positions = self.mrope_positions[:, :bs] - if self.enable_dp_attention: + if self.enable_dp_attention or self.enable_sp_layernorm: self.global_num_tokens_gpu.copy_( torch.tensor( [ @@ -471,7 +472,7 @@ class CudaGraphRunner: raw_num_token = raw_bs * self.num_tokens_per_bs # Pad - if self.enable_dp_attention: + if self.enable_dp_attention or self.enable_sp_layernorm: index = bisect.bisect_left( self.capture_bs, sum(forward_batch.global_num_tokens_cpu) ) @@ -497,7 +498,7 @@ class CudaGraphRunner: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) if forward_batch.mrope_positions is not None: self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) - if self.enable_dp_attention: + if self.enable_dp_attention or 
self.enable_sp_layernorm: self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) if hasattr(forward_batch.spec_info, "hidden_states"): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 716b61e22..8e0217277 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -281,9 +281,6 @@ class ModelRunner: if server_args.enable_deepep_moe: logger.info("DeepEP is turned on.") - assert ( - server_args.enable_dp_attention == True - ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'" def init_torch_distributed(self): logger.info("Init torch distributed begin.") diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 8ff4c4373..4b733a67c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -39,6 +39,8 @@ from sglang.srt.layers.dp_attention import ( get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, + tp_all_gather, + tp_reduce_scatter, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -278,7 +280,11 @@ class DeepseekV2MoE(nn.Module): topk_weights = torch.empty( (0, self.top_k), dtype=torch.float32, device=hidden_states.device ) - if forward_mode is not None and not forward_mode.is_idle(): + if ( + forward_mode is not None + and not forward_mode.is_idle() + and hidden_states.shape[0] > 0 + ): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) if self.n_shared_experts is not None: @@ -969,6 +975,14 @@ class DeepseekV2DecoderLayer(nn.Module): is_nextn: bool = False, prefix: str = "", ) -> None: + + def is_sparse_layer(l: int): + return ( + config.n_routed_experts is not None + and l >= config.first_k_dense_replace + and l % config.moe_layer_freq == 0 + ) + super().__init__() self.hidden_size = config.hidden_size 
rope_theta = getattr(config, "rope_theta", 10000) @@ -977,6 +991,8 @@ class DeepseekV2DecoderLayer(nn.Module): self.enable_dp_attention = global_server_args_dict["enable_dp_attention"] self.layer_id = layer_id self.dp_size = get_attention_dp_size() + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() if not global_server_args_dict["disable_mla"]: self.self_attn = DeepseekV2AttentionMLA( @@ -1019,16 +1035,13 @@ class DeepseekV2DecoderLayer(nn.Module): prefix=add_prefix("self_attn", prefix), ) - if is_nextn or ( - config.n_routed_experts is not None - and layer_id >= config.first_k_dense_replace - and layer_id % config.moe_layer_freq == 0 - ): + if is_nextn or is_sparse_layer(layer_id): self.mlp = DeepseekV2MoE( config=config, quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) + self.is_sparse = True else: self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, @@ -1037,6 +1050,14 @@ class DeepseekV2DecoderLayer(nn.Module): quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) + self.is_sparse = False + + self.input_is_scattered = ( + is_sparse_layer(layer_id - 1) + and global_server_args_dict["enable_deepep_moe"] + ) + self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -1049,6 +1070,23 @@ class DeepseekV2DecoderLayer(nn.Module): forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: + if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: + return self.forward_deepep( + positions, hidden_states, forward_batch, residual + ) + else: + return self.forward_normal( + positions, hidden_states, forward_batch, residual + ) + + def forward_normal( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> 
torch.Tensor: + if hidden_states.shape[0] == 0: residual = hidden_states else: @@ -1065,29 +1103,35 @@ class DeepseekV2DecoderLayer(nn.Module): forward_batch=forward_batch, ) + if self.attn_tp_size != 1 and self.input_is_scattered: + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + tp_all_gather( + list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states + ) + residual, local_residual = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + residual, + ) + tp_all_gather( + list(residual.tensor_split(self.attn_tp_size)), local_residual + ) + # Gather if get_tensor_model_parallel_world_size() > 1: # all gather and all reduce if self.dp_size != 1: - if global_server_args_dict["enable_deepep_moe"] and isinstance( - self.mlp, DeepseekV2MoE - ): - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) - return hidden_states, residual - else: - if get_attention_tp_rank() == 0: - hidden_states += residual - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer, - hidden_states, - ) - dp_gather_partial(hidden_states, local_hidden_states, forward_batch) - dp_scatter(residual, hidden_states, forward_batch) - hidden_states = self.post_attention_layernorm(hidden_states) + if self.attn_tp_rank == 0: + hidden_states += residual + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer, + hidden_states, + ) + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + dp_scatter(residual, hidden_states, forward_batch) + hidden_states = self.post_attention_layernorm(hidden_states) else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) hidden_states, residual = self.post_attention_layernorm( @@ -1101,6 +1145,7 @@ class DeepseekV2DecoderLayer(nn.Module): # Fully Connected 
hidden_states = self.mlp(hidden_states) + # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter # Scatter if self.dp_size != 1: # important: forward batch.gathered_buffer is used both after scatter and after gather. @@ -1113,6 +1158,82 @@ return hidden_states, residual + def forward_deepep( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + + if hidden_states.shape[0] == 0: + residual = hidden_states + else: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + if self.attn_tp_size != 1 and self.input_is_scattered: + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + tp_all_gather( + list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states + ) + + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + if self.attn_tp_size != 1: + if self.input_is_scattered: + tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) + hidden_states = tensor_list[self.attn_tp_rank] + tp_reduce_scatter(hidden_states, tensor_list) + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + else: + if self.attn_tp_rank == 0: + hidden_states += residual + tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) + hidden_states = tensor_list[self.attn_tp_rank] + tp_reduce_scatter(hidden_states, tensor_list) + residual = hidden_states + if hidden_states.shape[0] != 0: + hidden_states = self.post_attention_layernorm(hidden_states) + else: + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + 
hidden_states, residual + ) + hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + + if self.is_last_layer and self.attn_tp_size != 1: + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + tp_all_gather( + list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states + ) + residual, local_residual = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + residual, + ) + tp_all_gather( + list(residual.tensor_split(self.attn_tp_size)), local_residual + ) + + return hidden_states, residual + class DeepseekV2Model(nn.Module): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 59b6e73d4..6a0166b41 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -290,12 +290,17 @@ class ServerArgs: logger.warning( f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. " ) - # DeepEP MoE - if self.enable_deepep_moe: - self.ep_size = self.dp_size - logger.info( - f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size[{self.dp_size}]." - ) + + self.enable_sp_layernorm = False + # DeepEP MoE + if self.enable_deepep_moe: + self.ep_size = self.tp_size + self.enable_sp_layernorm = ( + self.dp_size < self.tp_size if self.enable_dp_attention else True + ) + logger.info( + f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." 
+ ) # Speculative Decoding if self.speculative_algorithm == "NEXTN": diff --git a/test/srt/test_moe_deepep.py b/test/srt/test_moe_deepep.py index 9c4194823..a25146eb5 100644 --- a/test/srt/test_moe_deepep.py +++ b/test/srt/test_moe_deepep.py @@ -12,7 +12,42 @@ from sglang.test.test_utils import ( ) -class TestDeepEPMoE(CustomTestCase): +class TestPureTP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--enable-deepep-moe", + "--disable-cuda-graph", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.5) + + +class TestDPAttn(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST