saving hidden_states.clone() (#7705)
This commit is contained in:
@@ -436,8 +436,8 @@ class LogitsProcessor(nn.Module):
|
|||||||
if self.do_tensor_parallel_all_gather_dp_attn:
|
if self.do_tensor_parallel_all_gather_dp_attn:
|
||||||
logits_metadata.compute_dp_attention_metadata(hidden_states)
|
logits_metadata.compute_dp_attention_metadata(hidden_states)
|
||||||
hidden_states, local_hidden_states = (
|
hidden_states, local_hidden_states = (
|
||||||
logits_metadata.gathered_buffer,
|
torch.empty_like(logits_metadata.gathered_buffer),
|
||||||
hidden_states.clone(),
|
hidden_states,
|
||||||
)
|
)
|
||||||
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
||||||
|
|
||||||
|
|||||||
@@ -1840,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
|
|
||||||
# NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
|
|
||||||
# See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
|
|
||||||
hidden_states = hidden_states.clone()
|
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
def op_comm_prepare_attn(
|
def op_comm_prepare_attn(
|
||||||
|
|||||||
Reference in New Issue
Block a user