Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)

Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
This commit is contained in:
Cheng Wan
2025-05-12 02:36:29 -04:00
committed by GitHub
parent 9f2c9568f0
commit 25c83fff6a
8 changed files with 71 additions and 23 deletions

View File

@@ -221,3 +221,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` |
| `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` |
| `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` |
| `enable_dp_lm_head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` |

View File

@@ -252,12 +252,12 @@ def dp_scatter(
)
def tp_reduce_scatter(
def attn_tp_reduce_scatter(
output: torch.Tensor,
input_list: List[torch.Tensor],
):
return get_attention_tp_group().reduce_scatter(output, input_list)
def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)

View File

@@ -23,15 +23,16 @@ import triton.language as tl
from torch import nn
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -198,12 +199,20 @@ class LogitsProcessor(nn.Module):
super().__init__()
self.config = config
self.logit_scale = logit_scale
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)
self.do_tensor_parallel_all_gather_dp_attn = (
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
)
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
not skip_all_gather and self.attn_tp_size > 1
)
self.do_tensor_parallel_all_gather_dp_attn = False
else:
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)
self.do_tensor_parallel_all_gather_dp_attn = (
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
)
self.final_logit_softcapping = getattr(
self.config, "final_logit_softcapping", None
)
@@ -442,7 +451,19 @@ class LogitsProcessor(nn.Module):
logits.mul_(self.logit_scale)
if self.do_tensor_parallel_all_gather:
logits = tensor_model_parallel_all_gather(logits)
if self.use_attn_tp_group:
global_logits = torch.empty(
(self.config.vocab_size, logits.shape[0]),
device=logits.device,
dtype=logits.dtype,
)
global_logits = global_logits.T
attn_tp_all_gather(
list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
)
logits = global_logits
else:
logits = tensor_model_parallel_all_gather(logits)
if self.do_tensor_parallel_all_gather_dp_attn:
logits, global_logits = (

View File

@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.parameter import BasevLLMParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
self,
num_embeddings: int,
embedding_dim: int,
*,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_tp: bool = True,
use_attn_tp_group: bool = False,
use_presharded_weights: bool = False,
):
super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):
self.enable_tp = enable_tp
if self.enable_tp:
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
if use_attn_tp_group:
tp_rank = get_attention_tp_rank()
self.tp_size = get_attention_tp_size()
else:
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
else:
assert use_attn_tp_group is False
tp_rank = 0
self.tp_size = 1
@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
self,
num_embeddings: int,
embedding_dim: int,
*,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_attn_tp_group: bool = False,
use_presharded_weights: bool = False,
):
super().__init__(
num_embeddings,
embedding_dim,
params_dtype,
org_num_embeddings,
padding_size,
quant_config,
prefix,
params_dtype=params_dtype,
org_num_embeddings=org_num_embeddings,
padding_size=padding_size,
quant_config=quant_config,
prefix=prefix,
use_attn_tp_group=use_attn_tp_group,
use_presharded_weights=use_presharded_weights,
)
self.quant_config = quant_config

View File

@@ -74,6 +74,7 @@ global_server_args_dict = {
"disable_radix_cache": ServerArgs.disable_radix_cache,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"enable_nan_detection": ServerArgs.enable_nan_detection,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,

View File

@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
tp_all_gather,
tp_reduce_scatter,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
@@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
tp_reduce_scatter(hidden_states, tensor_list)
attn_tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
@@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
tp_reduce_scatter(hidden_states, tensor_list)
attn_tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
@@ -1475,6 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_attention_dp_size()

View File

@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

View File

@@ -159,6 +159,7 @@ class ServerArgs:
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_dp_lm_head: bool = False
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -323,6 +324,11 @@ class ServerArgs:
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
)
if self.enable_dp_lm_head:
assert (
self.enable_dp_attention
), "Please enable dp attention when setting enable_dp_attention. "
# DeepEP MoE
self.enable_sp_layernorm = False
if self.enable_deepep_moe:
@@ -1055,6 +1061,11 @@ class ServerArgs:
action="store_true",
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
)
parser.add_argument(
"--enable-dp-lm-head",
action="store_true",
help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
)
parser.add_argument(
"--enable-ep-moe",
action="store_true",