Avoid saving hidden_states.clone() (#7705)

This commit is contained in:
Cheng Wan
2025-07-04 20:07:42 -07:00
committed by GitHub
parent 1964c325de
commit cb432f1770
2 changed files with 2 additions and 7 deletions

View File

@@ -436,8 +436,8 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather_dp_attn:
logits_metadata.compute_dp_attention_metadata(hidden_states)
hidden_states, local_hidden_states = (
logits_metadata.gathered_buffer,
hidden_states.clone(),
torch.empty_like(logits_metadata.gathered_buffer),
hidden_states,
)
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

View File

@@ -1840,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual, forward_batch
)
if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
# NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
# See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
hidden_states = hidden_states.clone()
return hidden_states, residual
def op_comm_prepare_attn(