Fix two issues related to --moe-dense-tp-size=1 (#5657)
Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
Co-authored-by: 颉沆 <xiehang.lsy@alibaba-inc.com>
This commit is contained in:
@@ -40,9 +40,9 @@ from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_reduce_scatter,
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -438,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.v_head_dim = v_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.dp_size = get_attention_dp_size()
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
|
||||
@@ -1133,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
|
||||
self.layer_id = layer_id
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
self.self_attn = DeepseekV2AttentionMLA(
|
||||
@@ -1184,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.input_is_scattered = (
|
||||
previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||
layer_id > 0
|
||||
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||
)
|
||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||
|
||||
@@ -1264,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
# all gather and all reduce
|
||||
if self.dp_size != 1:
|
||||
if self.local_dp_size != 1:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
@@ -1289,7 +1289,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
if self.dp_size != 1:
|
||||
if self.local_dp_size != 1:
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
@@ -1413,7 +1413,7 @@ class DeepseekV2Model(nn.Module):
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
|
||||
def get_input_embeddings(self) -> torch.Tensor:
|
||||
return self.embed_tokens
|
||||
@@ -1478,7 +1478,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
|
||||
def determine_n_share_experts_fusion(
|
||||
self, architecture: str = "DeepseekV3ForCausalLM"
|
||||
|
||||
Reference in New Issue
Block a user