Add unit tests for page_size > 1 and MLA, and an integration test for Flash Attention 3 (#4760)

This commit is contained in:
Yubo Wang
2025-04-07 23:20:51 -07:00
committed by GitHub
parent a7c3f74bec
commit 804d9f2e4c
6 changed files with 733 additions and 224 deletions

View File

@@ -548,8 +548,9 @@ class FlashAttentionBackend(AttentionBackend):
# Use Flash Attention for prefill
if not self.use_mla:
# Do multi-head attention
-kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-key_cache, value_cache = kv_cache[0], kv_cache[1]
+key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+layer.layer_id
+)
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
@@ -592,7 +593,6 @@ class FlashAttentionBackend(AttentionBackend):
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
@@ -659,8 +659,10 @@ class FlashAttentionBackend(AttentionBackend):
if not self.use_mla:
# Do multi-head attention
-kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-key_cache, value_cache = kv_cache[0], kv_cache[1]
+key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+layer.layer_id
+)
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)