Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)
Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    tp_all_gather,
-    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
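The only change in this hunk is the import rename: the generic tp_all_gather / tp_reduce_scatter collectives are replaced by variants scoped to the attention TP group. As a rough mental model, here is a minimal sketch assuming standard torch.distributed semantics, with the attention TP group passed explicitly; the real implementations in sglang.srt.layers.dp_attention resolve that group internally.

import torch.distributed as dist

def attn_tp_all_gather(output_list, local_tensor, group=None):
    # Gather every attention-TP rank's shard into the per-rank views
    # of a shared output buffer.
    dist.all_gather(output_list, local_tensor, group=group)

def attn_tp_reduce_scatter(output_tensor, input_list, group=None):
    # Sum the per-rank inputs across the attention TP group, leaving
    # each rank holding only its own shard of the result.
    dist.reduce_scatter(output_tensor, input_list, group=group)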
@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        tp_all_gather(
+        attn_tp_all_gather(
             list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
         )
 
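Only the collective's name changes; the gather pattern is unchanged. tensor_split produces views that alias the gathered buffer, so the all-gather writes each rank's local_hidden_states directly into place, and it tolerates token counts that do not divide evenly by attn_tp_size. A standalone illustration, with sizes invented for the example:

import torch

attn_tp_size = 4
gathered = torch.empty(10, 8)                  # buffer covering all 10 tokens
views = list(gathered.tensor_split(attn_tp_size))
print([v.shape[0] for v in views])             # [3, 3, 2, 2]: uneven split is fine
# Each view aliases `gathered`, so gathering into `views` fills the
# whole buffer with no extra copy.

The identical substitution appears again in the hunk at 1373 below.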
@@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
         if hidden_states.shape[0] != 0:
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
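On the reduce-scatter path, the rank first re-points hidden_states at its own view of tensor_list, then attn_tp_reduce_scatter sums the full tensors across the attention TP group and leaves each rank with its shard written in place. A single-process simulation of the numerics, assuming two ranks purely for illustration:

import torch

attn_tp_size = 2
inputs = [torch.randn(8, 4) for _ in range(attn_tp_size)]  # one per simulated rank
summed = torch.stack(inputs).sum(dim=0)
for rank in range(attn_tp_size):
    # After the collective, rank `rank` would hold this reduced shard:
    shard = summed.tensor_split(attn_tp_size)[rank]
    print(rank, shard.shape)                               # torch.Size([4, 4])

The hunk at 1349 below applies the same substitution, and additionally reseeds residual from the scattered shard.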
@@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        tp_all_gather(
+        attn_tp_all_gather(
             list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
         )
 
@@ -1475,6 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 config.hidden_size,
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
             )
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()
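This flag is the point of the PR: when enable_dp_lm_head is set, the LM head performs vocabulary parallelism across the attention TP group rather than the full TP group, so with DP attention each replica can compute logits for its own tokens without first gathering hidden states across DP ranks. A hedged sketch of the resulting vocab partition, assuming a plain contiguous shard per rank (the actual partitioning is handled inside ParallelLMHead):

vocab_size = 102400                      # DeepSeek-V2 vocab size, for illustration
attn_tp_size, attn_tp_rank = 2, 0
shard = vocab_size // attn_tp_size
lo, hi = attn_tp_rank * shard, (attn_tp_rank + 1) * shard
print(f"attn-TP rank {attn_tp_rank} owns vocab ids [{lo}, {hi})")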