feat: mtp support dp-attention (#6081)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
u4lr451
2025-06-17 15:33:28 +08:00
committed by GitHub
parent 8a10c4c3d9
commit 10d60cd41b
22 changed files with 641 additions and 151 deletions

View File

@@ -1399,7 +1399,9 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
self.layer_id = layer_id
self.is_nextn = is_nextn
self.self_attn = DeepseekV2AttentionMLA(
config=config,
hidden_size=self.hidden_size,
@@ -1500,6 +1502,11 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual, forward_batch
)
if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
# NOTE: this line resolves the degradation of the MTP acceptance rate on non-zero DP ranks.
# See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
hidden_states = hidden_states.clone()
return hidden_states, residual
def op_comm_prepare_attn(