Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)

Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
Cheng Wan authored 2025-05-12 02:36:29 -04:00, committed by GitHub
parent 9f2c9568f0
commit 25c83fff6a
8 changed files with 71 additions and 23 deletions
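
For context: with DP attention, each attention TP group holds a complete copy of the activations, so the LM head's vocabulary-parallel projection can be sharded and gathered within each attention TP group instead of across the full TP group, keeping logit computation free of cross-DP-group communication. Below is a minimal sketch of that idea, assuming torch.distributed is already initialized; VocabShardedLMHead and its signature are hypothetical, not sglang's ParallelLMHead API.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

class VocabShardedLMHead(nn.Module):
    # Hypothetical sketch: the process group passed in decides where the
    # vocab shards live and how far the final all-gather of logits reaches.
    def __init__(self, vocab_size: int, hidden_size: int, group: dist.ProcessGroup):
        super().__init__()
        self.group = group
        rank = dist.get_rank(group)
        self.world = dist.get_world_size(group)
        assert vocab_size % self.world == 0, "pad the vocab so it divides evenly"
        self.shard = vocab_size // self.world
        self.vocab_start = rank * self.shard
        # Each rank stores only its slice of the vocabulary projection
        # (weight initialization omitted in this sketch).
        self.weight = nn.Parameter(torch.empty(self.shard, hidden_size))

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # Partial logits for this rank's vocab shard.
        local = F.linear(hidden, self.weight)
        # Gather the shards only within `group`. Passing the attention TP
        # group here (rather than the full TP group) is what the new
        # use_attn_tp_group flag selects in the diff below.
        parts = [torch.empty_like(local) for _ in range(self.world)]
        dist.all_gather(parts, local, group=self.group)
        return torch.cat(parts, dim=-1)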

python/sglang/srt/models/deepseek_v2.py

@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    tp_all_gather,
-    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )
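
A note on the pattern in this hunk: tensor_split returns views into the preallocated forward_batch.gathered_buffer, so the all-gather writes each rank's shard into the buffer in place, with no concatenation or copy afterwards. A minimal equivalent with plain torch.distributed (gather_into_buffer is a hypothetical helper; attn_tp_all_gather does the same over the attention TP group):

import torch
import torch.distributed as dist

def gather_into_buffer(buffer: torch.Tensor, local: torch.Tensor, group=None) -> torch.Tensor:
    # Assumes buffer's first dimension divides evenly across ranks, as the
    # padded gathered_buffer does; tensor_split then yields equal-sized views.
    world = dist.get_world_size(group)
    shards = list(buffer.tensor_split(world))
    # all_gather fills each view directly, so `buffer` is the gathered result.
    dist.all_gather(shards, local, group=group)
    return buffer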
@@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             if hidden_states.shape[0] != 0:
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual
@@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)
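
The reduce-scatter direction is symmetric: every rank contributes the full tensor and keeps only the reduced slice at its own attn_tp_rank. A plain torch.distributed sketch under the same assumptions (reduce_scatter_slice is a hypothetical helper; the diff above writes the result into tensor_list[self.attn_tp_rank] in place, while a fresh output tensor is used here for simplicity):

import torch
import torch.distributed as dist

def reduce_scatter_slice(full: torch.Tensor, group=None) -> torch.Tensor:
    # Sum `full` elementwise across all ranks, then scatter the sum so each
    # rank ends up holding only its own slice.
    world = dist.get_world_size(group)
    inputs = list(full.tensor_split(world))
    out = torch.empty_like(inputs[dist.get_rank(group)])
    dist.reduce_scatter(out, inputs, group=group)
    return out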
@@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )
@@ -1475,6 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()
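
Usage note: the new behavior is gated on global_server_args_dict["enable_dp_lm_head"]. Assuming the usual mapping of sglang server args to CLI flags, it would be enabled alongside DP attention roughly like this (model path and parallel sizes are illustrative only):

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 \
    --tp 8 --dp 2 --enable-dp-attention --enable-dp-lm-head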