saving hidden_states.clone() (#7705)
This commit is contained in:
@@ -436,8 +436,8 @@ class LogitsProcessor(nn.Module):
|
||||
if self.do_tensor_parallel_all_gather_dp_attn:
|
||||
logits_metadata.compute_dp_attention_metadata(hidden_states)
|
||||
hidden_states, local_hidden_states = (
|
||||
logits_metadata.gathered_buffer,
|
||||
hidden_states.clone(),
|
||||
torch.empty_like(logits_metadata.gathered_buffer),
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
||||
|
||||
|
||||
@@ -1840,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
|
||||
# NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
|
||||
# See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
|
||||
hidden_states = hidden_states.clone()
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def op_comm_prepare_attn(
|
||||
|
||||
Reference in New Issue
Block a user