Improve DP attention (#4390)

Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-03-13 08:23:56 -07:00
committed by GitHub
parent f141298a3c
commit 8e66fbecee
9 changed files with 345 additions and 226 deletions

View File

@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.dp_attention import (
dp_gather,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
@@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module):
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
assert num_heads % attn_tp_size == 0
self.num_local_heads = num_heads // attn_tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
@@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
reduce_results=reduce_results,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope_wrapper(
@@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
), "short-circuiting allreduce will lead to hangs"
return hidden_states
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
use_dp=False,
reduce_results: bool = True,
layer_id: int = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module):
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads if use_dp else num_heads // tp_size
assert num_heads % attn_tp_size == 0
self.num_local_heads = num_heads // attn_tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if use_dp:
# For data parallel attention
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ReplicatedLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ReplicatedLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
)
self.kv_b_proj = ReplicatedLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = ReplicatedLinear(
self.num_heads * self.v_head_dim,
# For tensor parallel attention
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
else:
# For tensor parallel attention
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
prefix=add_prefix("q_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
# O projection.
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("o_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
@@ -542,38 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_vc = None
self.w_scale = None
self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
if self.enable_flashinfer_mla:
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
not self.flashinfer_mla_disable_ragged
and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
return (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
), "short-circuiting allreduce will lead to hangs"
return hidden_states
def no_absorb() -> bool:
if global_server_args_dict["enable_flashinfer_mla"]:
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
return (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
if no_absorb():
if self.no_absorb(forward_batch):
return self.forward_normal(positions, hidden_states, forward_batch)
else:
if _is_hip:
if (
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
self.rocm_fused_decode_mla
and forward_batch.forward_mode.is_decode()
):
return self.forward_absorb_fused_mla_rope(
@@ -845,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
def all_gather(
input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
):
all_lens = forward_batch.global_num_tokens_cpu
max_len = max(forward_batch.global_num_tokens_cpu)
if world_size == 1:
return input_tensor, 0, all_lens[0]
padded_tensor = torch.nn.functional.pad(
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
)
group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
gathered_tensors = torch.concat(
[
forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
for i in range(world_size)
]
)
start_index = 0 if rank == 0 else sum(all_lens[:rank])
end_index = start_index + all_lens[rank]
return gathered_tensors, start_index, end_index
class DeepseekV2DecoderLayer(nn.Module):
def __init__(
@@ -888,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = (
not global_server_args_dict["disable_mla"]
and global_server_args_dict["enable_dp_attention"]
)
if self.enable_dp_attention:
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.dp_size = get_attention_dp_size()
if not global_server_args_dict["disable_mla"]:
self.self_attn = DeepseekV2AttentionMLA(
config=config,
@@ -913,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
use_dp=self.enable_dp_attention,
reduce_results=False,
prefix=add_prefix("self_attn", prefix),
)
else:
@@ -933,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
reduce_results=False,
prefix=add_prefix("self_attn", prefix),
)
if is_nextn or (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
@@ -965,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if not forward_batch.forward_mode.is_idle():
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
# Scatter
if self.dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather(
hidden_states, local_hidden_states, forward_batch, self.layer_id
)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
# Fully Connected
if self.enable_dp_attention:
hidden_states, start_idx, end_idx = all_gather(
hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states[start_idx:end_idx]
else:
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@@ -1027,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_attention_dp_size()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Gather
if self.dp_size != 1:
input_ids, local_input_ids = (
torch.empty(
(forward_batch.gathered_buffer.shape[0],),
dtype=input_ids.dtype,
device=input_ids.device,
),
input_ids,
)
dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
@@ -1059,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
if global_server_args_dict["enable_dp_attention"]:
self.lm_head = ReplicatedLinear(
config.hidden_size,
config.vocab_size,
bias=False,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_attention_dp_size()
@torch.no_grad()
def forward(
@@ -1084,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
if self.dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)