Remove q concat in FA3 backend for DeepSeek decode (#5638)
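
Note: before this change, the DeepSeek MLA decode path built `q = torch.cat([q_nope_out, q_pe], dim=-1)` only for the FA3 backend to call `contiguous()` on it and slice the two halves back out. The hunks below instead thread the rotary part through as a `q_rope` keyword, so the backend only reshapes the two tensors and the concat disappears from the hot path; the old behavior is kept in `else` branches for callers that still pass a concatenated query. A minimal sketch of why both layouts feed the kernel the same values (the sizes are illustrative DeepSeek-V2-style numbers, not taken from this diff):

```python
import torch

# Illustrative sizes only: latent ("nope") part 512-wide, rotary part 64-wide per head.
tokens, heads, v_head_dim, rope_dim = 4, 16, 512, 64

q_nope = torch.randn(tokens, heads * v_head_dim)  # output of the no-position projection
q_pe = torch.randn(tokens, heads * rope_dim)      # output of the rotary projection

# Old path: concatenate per head, then the backend slices the halves back out.
q_cat = torch.cat(
    [q_nope.view(tokens, heads, v_head_dim), q_pe.view(tokens, heads, rope_dim)], dim=-1
)
old_nope = q_cat[:, :, :v_head_dim]
old_rope = q_cat[:, :, v_head_dim:]

# New path: skip the concat; the backend just views each part to (tokens, heads, dim).
new_nope = q_nope.view(-1, heads, v_head_dim)
new_rope = q_pe.view(-1, heads, rope_dim)

assert torch.equal(old_nope, new_nope) and torch.equal(old_rope, new_rope)
```
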
@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
         else:
             return self.forward_extend(
@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )

     def forward_decode(
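
Note: the three hunks above only thread `**kwargs` from the generic `AttentionBackend.forward` entry point into `forward_decode` and `forward_extend`, so a backend-specific keyword such as `q_rope` can pass through without widening the abstract signature. A simplified, hypothetical sketch of the resulting dispatch (stand-in classes, not the real ones):

```python
# Hypothetical stand-ins, only to show the keyword pass-through added above.
class _Mode:
    def is_decode(self) -> bool:
        return True

class _Batch:
    forward_mode = _Mode()

class BackendSketch:
    def forward(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(
                q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
            )
        return self.forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
        )

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True, q_rope=None):
        return "decode, q_rope given" if q_rope is not None else "decode, concatenated q"

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True, q_rope=None):
        return "extend"

print(BackendSketch().forward(None, None, None, None, _Batch(), q_rope=object()))
# -> decode, q_rope given
```
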
@@ -623,6 +623,8 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -815,9 +817,15 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]

             result = flash_attn_with_kvcache(
                 q=q_rope,
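
Note: in the new branch the incoming `q` is assumed to carry only the latent ("nope") part, `layer.v_head_dim` wide per head, while `q_rope` carries the `head_dim - v_head_dim` rotary part, so both tensors are merely viewed; the `else` branch keeps the previous copy-and-slice for callers that still pass a concatenated query. A small check under assumed DeepSeek-style sizes (`head_dim = 576`, `v_head_dim = 512`) showing that the views allocate nothing, unlike the old concatenated buffer:

```python
import torch

# Assumed sizes for illustration; the diff itself does not fix these numbers.
tp_q_head_num, head_dim, v_head_dim = 16, 576, 512
tokens = 2

q = torch.randn(tokens, tp_q_head_num * v_head_dim)                   # latent part only
q_rope_in = torch.randn(tokens, tp_q_head_num * (head_dim - v_head_dim))

# New branch: plain views, no allocation or copy of query data.
q_nope = q.view(-1, tp_q_head_num, v_head_dim)
q_rope = q_rope_in.view(-1, tp_q_head_num, head_dim - v_head_dim)
assert q_nope.data_ptr() == q.data_ptr()          # shares storage with the input
assert q_rope.data_ptr() == q_rope_in.data_ptr()

# Old branch (kept in the else): a concatenated query meant a fresh buffer per step.
q_all = torch.cat([q_nope, q_rope], dim=-1).contiguous().view(-1, tp_q_head_num, head_dim)
assert q_all.data_ptr() != q.data_ptr()
```
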
@@ -877,6 +885,8 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -1047,9 +1057,15 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )

-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
             max_seqlen_q = metadata.max_seq_len_q

             result = flash_attn_with_kvcache(
@@ -87,6 +87,7 @@ class RadixAttention(nn.Module):
         v,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         if k is not None:
             # For cross-layer sharing, kv can be None
@@ -95,5 +96,11 @@ class RadixAttention(nn.Module):
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         return forward_batch.attn_backend.forward(
-            q, k, v, self, forward_batch, save_kv_cache
+            q,
+            k,
+            v,
+            self,
+            forward_batch,
+            save_kv_cache,
+            **kwargs,
         )
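
Note: because the new parameter is a bare `**kwargs`, existing `RadixAttention.forward` call sites that pass only `(q, k, v, forward_batch[, save_kv_cache])` behave exactly as before; only opt-in callers add backend-specific keywords, and those are forwarded verbatim to `attn_backend.forward`. A tiny, hypothetical check of that contract:

```python
# Hypothetical stand-in for the forwarding contract added above.
def radix_forward(q, k, v, forward_batch, save_kv_cache=True, **kwargs):
    # mirrors: return forward_batch.attn_backend.forward(
    #              q, k, v, self, forward_batch, save_kv_cache, **kwargs)
    return {"save_kv_cache": save_kv_cache, **kwargs}

assert radix_forward(1, 2, 3, None) == {"save_kv_cache": True}      # old call sites unchanged
assert radix_forward(1, 2, 3, None, q_rope="pe")["q_rope"] == "pe"  # new opt-in keyword
```
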
@@ -751,10 +751,15 @@ class DeepseekV2AttentionMLA(nn.Module):

         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        q = torch.cat([q_nope_out, q_pe], dim=-1)
         k = torch.cat([k_nope, k_pe], dim=-1)

-        attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+        if self.attention_backend == "fa3":
+            attn_output = self.attn_mqa(
+                q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+            )
+        else:
+            q = torch.cat([q_nope_out, q_pe], dim=-1)
+            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.use_deep_gemm_bmm:
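
Note: the call site gates on `self.attention_backend == "fa3"` because only `FlashAttentionBackend` gained the `q_rope` parameter in this change; other backends still expect the concatenated query and would reject the extra keyword. A hedged sketch of that rationale with stand-in functions (not the real backends):

```python
def fa3_decode(q, k, v, q_rope=None):
    # Stand-in: the FA3 backend now accepts the rotary part as a keyword.
    return "ok"

def other_decode(q, k, v):
    # Stand-in: other backends were not changed and have no such parameter.
    return "ok"

fa3_decode(None, None, None, q_rope=object())        # accepted
try:
    other_decode(None, None, None, q_rope=object())  # rejected
except TypeError as exc:
    print(exc)  # got an unexpected keyword argument 'q_rope'
```
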