Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)

Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
This commit is contained in:
Cheng Wan
2025-05-12 02:36:29 -04:00
committed by GitHub
parent 9f2c9568f0
commit 25c83fff6a
8 changed files with 71 additions and 23 deletions

View File

@@ -252,12 +252,12 @@ def dp_scatter(
)
def tp_reduce_scatter(
def attn_tp_reduce_scatter(
output: torch.Tensor,
input_list: List[torch.Tensor],
):
return get_attention_tp_group().reduce_scatter(output, input_list)
def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)

View File

@@ -23,15 +23,16 @@ import triton.language as tl
from torch import nn
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -198,12 +199,20 @@ class LogitsProcessor(nn.Module):
super().__init__()
self.config = config
self.logit_scale = logit_scale
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)
self.do_tensor_parallel_all_gather_dp_attn = (
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
)
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
not skip_all_gather and self.attn_tp_size > 1
)
self.do_tensor_parallel_all_gather_dp_attn = False
else:
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)
self.do_tensor_parallel_all_gather_dp_attn = (
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
)
self.final_logit_softcapping = getattr(
self.config, "final_logit_softcapping", None
)
@@ -442,7 +451,19 @@ class LogitsProcessor(nn.Module):
logits.mul_(self.logit_scale)
if self.do_tensor_parallel_all_gather:
logits = tensor_model_parallel_all_gather(logits)
if self.use_attn_tp_group:
global_logits = torch.empty(
(self.config.vocab_size, logits.shape[0]),
device=logits.device,
dtype=logits.dtype,
)
global_logits = global_logits.T
attn_tp_all_gather(
list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
)
logits = global_logits
else:
logits = tensor_model_parallel_all_gather(logits)
if self.do_tensor_parallel_all_gather_dp_attn:
logits, global_logits = (

View File

@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.parameter import BasevLLMParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
self,
num_embeddings: int,
embedding_dim: int,
*,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_tp: bool = True,
use_attn_tp_group: bool = False,
use_presharded_weights: bool = False,
):
super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):
self.enable_tp = enable_tp
if self.enable_tp:
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
if use_attn_tp_group:
tp_rank = get_attention_tp_rank()
self.tp_size = get_attention_tp_size()
else:
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
else:
assert use_attn_tp_group is False
tp_rank = 0
self.tp_size = 1
@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
self,
num_embeddings: int,
embedding_dim: int,
*,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_attn_tp_group: bool = False,
use_presharded_weights: bool = False,
):
super().__init__(
num_embeddings,
embedding_dim,
params_dtype,
org_num_embeddings,
padding_size,
quant_config,
prefix,
params_dtype=params_dtype,
org_num_embeddings=org_num_embeddings,
padding_size=padding_size,
quant_config=quant_config,
prefix=prefix,
use_attn_tp_group=use_attn_tp_group,
use_presharded_weights=use_presharded_weights,
)
self.quant_config = quant_config

View File

@@ -74,6 +74,7 @@ global_server_args_dict = {
"disable_radix_cache": ServerArgs.disable_radix_cache,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"enable_nan_detection": ServerArgs.enable_nan_detection,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,

View File

@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
tp_all_gather,
tp_reduce_scatter,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
@@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
tp_reduce_scatter(hidden_states, tensor_list)
attn_tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
@@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
tp_reduce_scatter(hidden_states, tensor_list)
attn_tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
@@ -1475,6 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_attention_dp_size()

View File

@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

View File

@@ -159,6 +159,7 @@ class ServerArgs:
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_dp_lm_head: bool = False
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -323,6 +324,11 @@ class ServerArgs:
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
)
if self.enable_dp_lm_head:
assert (
self.enable_dp_attention
), "Please enable dp attention when setting enable_dp_lm_head."
# DeepEP MoE
self.enable_sp_layernorm = False
if self.enable_deepep_moe:
@@ -1055,6 +1061,11 @@ class ServerArgs:
action="store_true",
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
)
parser.add_argument(
"--enable-dp-lm-head",
action="store_true",
help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
)
parser.add_argument(
"--enable-ep-moe",
action="store_true",