diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 36c44b57e..50b888cbf 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -221,3 +221,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` | | `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` | | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` | +| `enable_dp_lm_head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. 
| `False` | diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 2cc399ab7..0f1e453bf 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -252,12 +252,12 @@ def dp_scatter( ) -def tp_reduce_scatter( +def attn_tp_reduce_scatter( output: torch.Tensor, input_list: List[torch.Tensor], ): return get_attention_tp_group().reduce_scatter(output, input_list) -def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): +def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): return get_attention_tp_group().all_gather(input_, tensor_list=output_list) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 4958c6d04..5a4f07817 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -23,15 +23,16 @@ import triton.language as tl from torch import nn from sglang.srt.distributed import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather, dp_gather_replicate, dp_scatter, get_attention_dp_rank, get_attention_dp_size, + get_attention_tp_size, ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -198,12 +199,20 @@ class LogitsProcessor(nn.Module): super().__init__() self.config = config self.logit_scale = logit_scale - self.do_tensor_parallel_all_gather = ( - not skip_all_gather and get_tensor_model_parallel_world_size() > 1 - ) - self.do_tensor_parallel_all_gather_dp_attn = ( - self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1 - ) + self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"] + if self.use_attn_tp_group: + self.attn_tp_size = get_attention_tp_size() + 
self.do_tensor_parallel_all_gather = ( + not skip_all_gather and self.attn_tp_size > 1 + ) + self.do_tensor_parallel_all_gather_dp_attn = False + else: + self.do_tensor_parallel_all_gather = ( + not skip_all_gather and get_tensor_model_parallel_world_size() > 1 + ) + self.do_tensor_parallel_all_gather_dp_attn = ( + self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1 + ) self.final_logit_softcapping = getattr( self.config, "final_logit_softcapping", None ) @@ -442,7 +451,19 @@ class LogitsProcessor(nn.Module): logits.mul_(self.logit_scale) if self.do_tensor_parallel_all_gather: - logits = tensor_model_parallel_all_gather(logits) + if self.use_attn_tp_group: + global_logits = torch.empty( + (self.config.vocab_size, logits.shape[0]), + device=logits.device, + dtype=logits.dtype, + ) + global_logits = global_logits.T + attn_tp_all_gather( + list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits + ) + logits = global_logits + else: + logits = tensor_model_parallel_all_gather(logits) if self.do_tensor_parallel_all_gather_dp_attn: logits, global_logits = ( diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index ebc148feb..ec7c140ae 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -13,6 +13,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module): self, num_embeddings: int, embedding_dim: int, + *, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: 
Optional[QuantizationConfig] = None, prefix: str = "", enable_tp: bool = True, + use_attn_tp_group: bool = False, use_presharded_weights: bool = False, ): super().__init__() @@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module): self.enable_tp = enable_tp if self.enable_tp: - tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + if use_attn_tp_group: + tp_rank = get_attention_tp_rank() + self.tp_size = get_attention_tp_size() + else: + tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() else: + assert use_attn_tp_group is False tp_rank = 0 self.tp_size = 1 @@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding): self, num_embeddings: int, embedding_dim: int, + *, bias: bool = False, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_attn_tp_group: bool = False, use_presharded_weights: bool = False, ): super().__init__( num_embeddings, embedding_dim, - params_dtype, - org_num_embeddings, - padding_size, - quant_config, - prefix, + params_dtype=params_dtype, + org_num_embeddings=org_num_embeddings, + padding_size=padding_size, + quant_config=quant_config, + prefix=prefix, + use_attn_tp_group=use_attn_tp_group, use_presharded_weights=use_presharded_weights, ) self.quant_config = quant_config diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a797a7f3a..ac4b4edcb 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -74,6 +74,7 @@ global_server_args_dict = { "disable_radix_cache": ServerArgs.disable_radix_cache, "enable_deepep_moe": ServerArgs.enable_deepep_moe, "enable_dp_attention": ServerArgs.enable_dp_attention, + "enable_dp_lm_head": ServerArgs.enable_dp_lm_head, 
"enable_ep_moe": ServerArgs.enable_ep_moe, "enable_nan_detection": ServerArgs.enable_nan_detection, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9770de1c3..e8ef96a6e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -36,13 +36,13 @@ from sglang.srt.distributed import ( ) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather, + attn_tp_reduce_scatter, dp_gather_partial, dp_scatter, get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, - tp_all_gather, - tp_reduce_scatter, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module): forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - tp_all_gather( + attn_tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) @@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module): if self.input_is_scattered: tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank] - tp_reduce_scatter(hidden_states, tensor_list) + attn_tp_reduce_scatter(hidden_states, tensor_list) if hidden_states.shape[0] != 0: hidden_states, residual = self.post_attention_layernorm( hidden_states, residual @@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states += residual tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank] - tp_reduce_scatter(hidden_states, tensor_list) + attn_tp_reduce_scatter(hidden_states, tensor_list) residual = hidden_states if hidden_states.shape[0] != 0: hidden_states = self.post_attention_layernorm(hidden_states) @@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module): 
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - tp_all_gather( + attn_tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) @@ -1475,6 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.dp_size = get_attention_dp_size() diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index ab884ad9d..dc4d8f9df 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, @@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2e3b6c4df..a780976e3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -159,6 +159,7 @@ class ServerArgs: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False + enable_dp_lm_head: bool = False enable_ep_moe: bool = False enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" @@ -323,6 +324,11 @@ class ServerArgs: f"DP attention is enabled. 
The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. " ) + if self.enable_dp_lm_head: + assert ( + self.enable_dp_attention + ), "Please enable dp attention when setting enable_dp_lm_head. " + # DeepEP MoE self.enable_sp_layernorm = False if self.enable_deepep_moe: @@ -1055,6 +1061,11 @@ class ServerArgs: action="store_true", help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.", ) + parser.add_argument( + "--enable-dp-lm-head", + action="store_true", + help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.", + ) parser.add_argument( "--enable-ep-moe", action="store_true",