Reduce computation and communication in DP attention (#4521)

Author: Cheng Wan
Date: 2025-03-18 16:41:36 -04:00
Committed by: GitHub
Parent: 9e0186f352
Commit: 3196999f63
5 changed files with 70 additions and 80 deletions

python/sglang/srt/models/deepseek_v2.py

@@ -33,7 +33,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -939,11 +939,47 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if residual is None:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            # Self Attention
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                if get_attention_tp_rank() == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
+
+        # Fully Connected
+        hidden_states = self.mlp(hidden_states)

         # Scatter
         if self.dp_size != 1:
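For readers new to DP attention, the sketch below walks through the gather-partial/scatter bookkeeping that the new forward path above relies on. It emulates 2 attention-TP ranks and 2 DP ranks in a single process with plain tensors; the real dp_gather_partial and dp_scatter imported from sglang.srt.layers.dp_attention are distributed collectives, so the explicit sum/cat/slice below is only an assumed reading of the arithmetic they perform, not the library code.

import torch

hidden = 4
local_tokens = [3, 5]          # rows owned by DP rank 0 and DP rank 1
tp = 2                         # attention tensor-parallel degree

# Local residuals and per-TP-rank partial attention outputs (made-up values).
residual = [torch.randn(n, hidden) for n in local_tokens]
attn_partial = [[torch.randn(n, hidden) for _ in range(tp)] for n in local_tokens]

# Only attention-TP rank 0 folds the residual into its partial output, so the
# residual is counted exactly once when the partial contributions are summed.
for dp in range(len(local_tokens)):
    attn_partial[dp][0] += residual[dp]

# "dp_gather_partial": every rank ends up with the whole batch, with the TP
# partial sums reduced along the way (emulated here by sum + cat).
gathered = torch.cat([sum(parts) for parts in attn_partial], dim=0)   # [8, hidden]

# "dp_scatter(residual, ...)": each DP rank keeps only its own rows of the
# gathered tensor; those rows become the residual for the MLP block.
offsets = [0, *torch.tensor(local_tokens).cumsum(0).tolist()]
new_residual = [gathered[offsets[i]:offsets[i + 1]] for i in range(len(local_tokens))]

assert [r.shape[0] for r in new_residual] == local_tokens

The rank-0 residual add is the point of the toy: the residual enters the reduction exactly once, so after the scatter each rank's rows already hold attention output plus residual, which appears to be why post_attention_layernorm is then called without a residual argument.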
@@ -955,31 +991,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
             dp_scatter(hidden_states, global_hidden_states, forward_batch)

-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
-
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.dp_size != 1:
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather(
-                    hidden_states, local_hidden_states, forward_batch, self.layer_id
-                )
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
-
         return hidden_states, residual
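Read together with the hunk above it, this removal summarizes the change in per-layer ordering (the outline is an interpretation of the two hunks, not text from the commit; dp_size > 1 assumed):

    old: input_layernorm -> dp_scatter -> self_attn -> dp_gather -> post_attention_layernorm(+residual) -> mlp -> return in the gathered layout
    new: input_layernorm -> self_attn -> residual added on attention-TP rank 0 -> dp_gather_partial -> dp_scatter into residual -> post_attention_layernorm -> mlp -> dp_scatter -> return in the local layout

The input layernorm and the residual handling now touch only each rank's local tokens, a layer skips attention entirely when its local batch is empty, and the full per-layer gather becomes a partial one, which is where the reduced computation and communication in the commit title come from.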
@@ -1025,18 +1036,6 @@ class DeepseekV2Model(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        # Gather
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -1087,15 +1086,6 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
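The last two hunks are the model-level counterparts of the layer change: since each decoder layer now both consumes and returns tensors in the local, per-DP-rank layout, the embedding no longer needs input_ids gathered across DP ranks, and the final hidden states no longer need a dp_scatter before the logits processor. A trivial shape check of that contract, with made-up sizes (the only assumption taken from the diff is that rows stay aligned with this rank's forward_batch.input_ids):

import torch

num_local_tokens, hidden_size = 5, 16                       # made-up per-rank sizes
hidden_states = torch.randn(num_local_tokens, hidden_size)  # what the last layer now returns
input_ids = torch.randint(0, 32000, (num_local_tokens,))    # this rank's own token ids

# Rows already correspond 1:1 to the local tokens, so the logits processor can
# consume hidden_states directly; no extra scatter is needed beforehand.
assert hidden_states.shape[0] == input_ids.shape[0]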