Remove q concat in FA3 backend for DeepSeek decode (#5638)
This commit is contained in:
@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Run forward on an attention layer."""
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
|
||||
layer,
|
||||
forward_batch,
|
||||
save_kv_cache=save_kv_cache,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self.forward_extend(
|
||||
@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
|
||||
layer,
|
||||
forward_batch,
|
||||
save_kv_cache=save_kv_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward_decode(
|
||||
|
||||
@@ -623,6 +623,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
# For multi-head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
@@ -815,9 +817,15 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
c_kv_cache = c_kv.view(
|
||||
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
|
||||
)
|
||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||
if q_rope is not None:
|
||||
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
q_rope = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
else:
|
||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||
|
||||
result = flash_attn_with_kvcache(
|
||||
q=q_rope,
|
||||
@@ -877,6 +885,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
# For multi-head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
@@ -1047,9 +1057,15 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
|
||||
)
|
||||
|
||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||
if q_rope is not None:
|
||||
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
q_rope = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
else:
|
||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||
max_seqlen_q = metadata.max_seq_len_q
|
||||
|
||||
result = flash_attn_with_kvcache(
|
||||
|
||||
@@ -87,6 +87,7 @@ class RadixAttention(nn.Module):
|
||||
v,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if k is not None:
|
||||
# For cross-layer sharing, kv can be None
|
||||
@@ -95,5 +96,11 @@ class RadixAttention(nn.Module):
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
|
||||
return forward_batch.attn_backend.forward(
|
||||
q, k, v, self, forward_batch, save_kv_cache
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self,
|
||||
forward_batch,
|
||||
save_kv_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -751,10 +751,15 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
|
||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||
|
||||
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||
if self.attention_backend == "fa3":
|
||||
attn_output = self.attn_mqa(
|
||||
q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
|
||||
)
|
||||
else:
|
||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
if self.use_deep_gemm_bmm:
|
||||
|
||||
Reference in New Issue
Block a user