Remove q concat in FA3 backend for DeepSeek decode (#5638)
This commit is contained in:
@@ -87,6 +87,7 @@ class RadixAttention(nn.Module):
|
||||
v,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if k is not None:
|
||||
# For cross-layer sharing, kv can be None
|
||||
@@ -95,5 +96,11 @@ class RadixAttention(nn.Module):
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
|
||||
return forward_batch.attn_backend.forward(
|
||||
q, k, v, self, forward_batch, save_kv_cache
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self,
|
||||
forward_batch,
|
||||
save_kv_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user