Remove q concat in FA3 backend for DeepSeek decode (#5638)
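Before this change, the FA3 MLA (multi-head latent attention) path assumed the caller had already concatenated the no-RoPE and RoPE parts of the query into a single tensor, which the backend then split back apart via `.contiguous().view(...)`. With this change, DeepSeek decode (and the matching extend path) can pass `q_rope` as a separate tensor, so the concat and the re-split copy are no longer needed. A minimal, self-contained sketch of the equivalence; the head count and dimensions are illustrative DeepSeek-MLA-style values, not read from any real config:

import torch

num_tokens, heads = 8, 16
v_head_dim, rope_dim = 512, 64          # illustrative MLA dims
head_dim = v_head_dim + rope_dim

q_nope = torch.randn(num_tokens, heads, v_head_dim)
q_rope = torch.randn(num_tokens, heads, rope_dim)

# Old path: the caller concatenates, the backend splits the result again.
q_cat = torch.cat([q_nope, q_rope], dim=-1).view(num_tokens, -1)
q_all = q_cat.contiguous().view(-1, heads, head_dim)
old_nope, old_rope = q_all[:, :, :v_head_dim], q_all[:, :, v_head_dim:]

# New path: the backend just reshapes the two tensors it was handed.
new_nope = q_nope.view(-1, heads, v_head_dim)
new_rope = q_rope.view(-1, heads, rope_dim)

# The same values reach flash_attn_with_kvcache either way.
assert torch.equal(old_nope, new_nope) and torch.equal(old_rope, new_rope)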
@@ -623,6 +623,8 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -815,9 +817,15 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
 
             result = flash_attn_with_kvcache(
                 q=q_rope,
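The same q handling is repeated in the decode path below. As a self-contained mock of the new control flow (split_q is a hypothetical helper invented here; the real backend does this inline inside its forward methods):

from typing import Optional, Tuple
import torch

def split_q(
    q: torch.Tensor,
    heads: int,
    head_dim: int,
    v_head_dim: int,
    q_rope: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Mirrors the q handling in the hunk above; split_q itself is hypothetical."""
    if q_rope is not None:
        # New path: q holds only the no-RoPE part, q_rope arrives separately.
        q_nope = q.view(-1, heads, v_head_dim)
        q_rope = q_rope.view(-1, heads, head_dim - v_head_dim)
    else:
        # Legacy path: q is the concatenated tensor and must be split.
        q_all = q.contiguous().view(-1, heads, head_dim)
        q_nope = q_all[:, :, :v_head_dim]
        q_rope = q_all[:, :, v_head_dim:]
    return q_nope, q_rope

# Works with either calling convention (shapes are illustrative).
q_sep, rope_sep = torch.randn(4, 16, 512), torch.randn(4, 16, 64)
q_nope, q_rope = split_q(q_sep, heads=16, head_dim=576, v_head_dim=512, q_rope=rope_sep)

q_legacy = torch.randn(4, 16 * 576)
q_nope, q_rope = split_q(q_legacy, heads=16, head_dim=576, v_head_dim=512)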
@@ -877,6 +885,8 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -1047,9 +1057,15 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
 
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
             max_seqlen_q = metadata.max_seq_len_q
 
             result = flash_attn_with_kvcache(
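For decode this code runs on every step, so the practical effect is skipping the caller-side torch.cat and the backend's `.contiguous()` copy whenever q_rope is supplied: the views in the new branch only reinterpret storage the caller already owns. A small self-contained check of that (shapes are illustrative, not taken from a real config):

import torch

num_tokens, heads, v_head_dim, rope_dim = 8, 16, 512, 64

q_nope = torch.randn(num_tokens, heads, v_head_dim)
q_rope = torch.randn(num_tokens, heads, rope_dim)

# Old path: torch.cat allocates a fresh (num_tokens, heads, 576) buffer.
q_cat = torch.cat([q_nope, q_rope], dim=-1)
assert q_cat.data_ptr() != q_nope.data_ptr()

# New path: plain views share storage with the tensors the caller already has.
assert q_nope.view(-1, heads, v_head_dim).data_ptr() == q_nope.data_ptr()
assert q_rope.view(-1, heads, rope_dim).data_ptr() == q_rope.data_ptr()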