Fix two issues related to --moe-dense-tp-size=1 (#5657)

Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
Co-authored-by: 颉沆 <xiehang.lsy@alibaba-inc.com>
This commit is contained in:
Cheng Wan
2025-05-13 02:51:39 -04:00
committed by GitHub
parent 1ab14c4c5c
commit b2e95f62b4
6 changed files with 119 additions and 45 deletions

View File

@@ -40,9 +40,9 @@ from sglang.srt.layers.dp_attention import (
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
-    get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
+    get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
@@ -438,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
@@ -1133,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.self_attn = DeepseekV2AttentionMLA(
@@ -1184,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
)
self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
@@ -1264,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
@@ -1289,7 +1289,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
@@ -1413,7 +1413,7 @@ class DeepseekV2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()
def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens
@@ -1478,7 +1478,7 @@ class DeepseekV2ForCausalLM(nn.Module):
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()
def determine_n_share_experts_fusion(
self, architecture: str = "DeepseekV3ForCausalLM"