MLA prefill w/o weight absorption (#2349)

This commit is contained in:
Ke Bao
2024-12-05 01:50:28 +08:00
committed by GitHub
parent eb0c1f5373
commit ec52464dde
8 changed files with 166 additions and 36 deletions

View File

@@ -48,11 +48,13 @@ class RadixAttention(nn.Module):
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention
def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
    """Dispatch attention for this layer to the batch's attention backend.

    Args:
        q: Query tensor.
        k: Key tensor, or ``None`` — for cross-layer KV sharing the caller
            may pass no k/v and reuse another layer's cache.
        v: Value tensor; must be non-``None`` whenever ``k`` is given.
        forward_batch: Batch metadata object carrying ``attn_backend``,
            which performs the actual attention computation.
        save_kv_cache: Whether the backend should write k/v into the KV
            cache. Defaults to ``True``, preserving the previous behavior
            for existing callers (new keyword is backward-compatible).

    Returns:
        The attention output produced by ``forward_batch.attn_backend``.
    """
    if k is not None:
        # For cross-layer sharing, kv can be None
        assert v is not None
        # Flatten to (tokens, heads, head_dim) as the backend expects.
        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
    return forward_batch.attn_backend.forward(
        q, k, v, self, forward_batch, save_kv_cache
    )