feat: mtp support dp-attention (#6081)
Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
@@ -1399,7 +1399,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
+        self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1500,6 +1502,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
+
+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
 
         return hidden_states, residual
 
     def op_comm_prepare_attn(
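Below is a minimal, self-contained sketch (not the sglang implementation; the buffer and variable names are hypothetical) of the tensor-aliasing hazard that the added `hidden_states.clone()` guards against: under DP attention a rank's hidden states can be a view into a shared padded buffer, and if that buffer is later overwritten in place, any saved reference to the view is silently corrupted, which is consistent with the MTP reception-rate degradation described in the linked discussion.

import torch

# Hypothetical padded buffer shared across DP ranks (for illustration only).
global_buffer = torch.zeros(8, 4)        # [max_tokens_across_ranks, hidden_size]
local_slice = global_buffer[2:5]         # this rank's tokens: a view, not a copy

# Pretend these are the hidden states the target model produced for this rank.
local_slice.copy_(torch.randn(3, 4))

saved_without_clone = local_slice        # still aliases global_buffer
saved_with_clone = local_slice.clone()   # owns its own storage

# Later the shared buffer is reused and overwritten in place.
global_buffer.zero_()

print(saved_without_clone.abs().sum())   # tensor(0.)  -> silently corrupted
print(saved_with_clone.abs().sum())      # non-zero    -> values preserved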