Remove q concat in FA3 backend for DeepSeek decode (#5638)

This commit is contained in:
Ke Bao
2025-04-23 02:43:12 +08:00
committed by GitHub
parent 917324862e
commit 6b6e748775
4 changed files with 40 additions and 9 deletions

View File

@@ -87,6 +87,7 @@ class RadixAttention(nn.Module):
v,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
if k is not None:
# For cross-layer sharing, kv can be None
@@ -95,5 +96,11 @@ class RadixAttention(nn.Module):
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
return forward_batch.attn_backend.forward(
q, k, v, self, forward_batch, save_kv_cache
q,
k,
v,
self,
forward_batch,
save_kv_cache,
**kwargs,
)