Support Page Size > 1 for FA3 (#4832)
Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
@@ -57,6 +57,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
@@ -78,6 +79,17 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
         if forward_batch.forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
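The hunk above is the core of the change: when page_size > 1, the token-level req_to_token mapping is sampled at every page_size-th column and divided by page_size, so each remaining entry becomes a page index. A minimal standalone sketch of that indexing, with made-up slot values that are not taken from this commit:

import torch

# Hypothetical req_to_token row: the KV-pool slot of every cached token of one request.
page_size = 4
token_slots = torch.tensor([[8, 9, 10, 11, 20, 21, 22, 23, 4, 5]])

# Same two steps as the diff: keep every page_size-th column, then divide by page_size.
strided_indices = torch.arange(0, token_slots.shape[1], page_size)  # [0, 4, 8]
page_table = token_slots[:, strided_indices] // page_size
print(page_table)  # tensor([[2, 5, 1]]) -> this request's KV cache occupies pages 2, 5 and 1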
@@ -132,11 +144,21 @@ class FlashAttentionBackend(AttentionBackend):
         )
         kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         key_cache, value_cache = kv_cache[0], kv_cache[1]
+
+        key_cache = key_cache.view(
+            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+        )
+        value_cache = value_cache.view(
+            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+        )
+
+        page_table = metadata.page_table
+
         o = flash_attn_with_kvcache(
             q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
+            k_cache=key_cache,
+            v_cache=value_cache,
+            page_table=page_table,
             cache_seqlens=metadata.cache_seqlens_int32,
             cu_seqlens_q=metadata.cu_seqlens_q,
             cu_seqlens_k_new=metadata.cu_seqlens_k,
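Both attention paths now view the flat KV buffer as (num_pages, page_size, num_heads, head_dim) instead of faking a page dimension of one with unsqueeze(1). A shape-only sketch; the sizes below are assumptions for illustration, not values from the commit:

import torch

page_size, num_heads, head_dim = 4, 8, 128
total_slots = 1024  # assumed KV-pool size, a multiple of page_size
flat_kv = torch.zeros(total_slots, num_heads, head_dim)

# Old path: every slot was treated as its own page of size 1.
paged_old = flat_kv.unsqueeze(1)                               # (1024, 1, 8, 128)

# New path: page_size consecutive slots form one page.
paged_new = flat_kv.view(-1, page_size, num_heads, head_dim)   # (256, 4, 8, 128)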
@@ -175,13 +197,11 @@ class FlashAttentionBackend(AttentionBackend):
         # Get KV cache
         kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         key_cache, value_cache = kv_cache[0], kv_cache[1]

         # Use precomputed metadata
         metadata = self.forward_metadata

         # Pre-reshape query tensor
         q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -191,11 +211,20 @@ class FlashAttentionBackend(AttentionBackend):
             else (-1, -1)
         )
         # Run attention with precomputed values
+        key_cache = key_cache.view(
+            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+        )
+        value_cache = value_cache.view(
+            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+        )
+
+        page_table = metadata.page_table
+
         o = flash_attn_with_kvcache(
             q=q_reshaped,
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
+            k_cache=key_cache,
+            v_cache=value_cache,
+            page_table=page_table,
             cache_seqlens=metadata.cache_seqlens_int32,
             cu_seqlens_q=metadata.cu_seqlens_q,
             cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -207,7 +236,6 @@ class FlashAttentionBackend(AttentionBackend):
             k_descale=layer.k_scale,
             v_descale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def init_cuda_graph_state(self, max_bs: int):
@@ -223,7 +251,13 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
             ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
         }
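The CUDA-graph page table buffer is now sized in pages rather than tokens, with ceiling division so that a partially filled last page still gets an entry. Illustrative numbers, not taken from the commit:

max_context_len, page_size = 10, 4
num_pages = (max_context_len + page_size - 1) // page_size
print(num_pages)  # 3: pages 0, 1, 2 cover token positions 0 through 9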
@@ -252,6 +286,7 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
             req_pool_indices, :
         ]

         if forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
@@ -287,14 +322,11 @@ class FlashAttentionBackend(AttentionBackend):
                 torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )

-            # Only zero out the part out of max_len_k
-            metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
-            # Then do the copy
-            metadata.page_table[:, : metadata.max_seq_len_k].copy_(
-                self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
-            )
-
-            self.forward_decode_metadata = metadata
+            metadata.page_table = self.req_to_token[
+                :, self.decode_cuda_graph_metadata["strided_indices"]
+            ]
+            metadata.page_table = metadata.page_table[req_pool_indices[:bs]]
+            self.forward_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
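In the CUDA-graph replay path the page table is rebuilt in two indexing steps: take every page_size-th column of req_to_token for all request slots, then keep only the rows of the requests in the current batch. A sketch of that indexing with made-up table contents and batch indices:

import torch

page_size = 2
req_to_token = torch.tensor([
    [0, 1, 6, 7, 2, 3],   # request slot 0
    [4, 5, 8, 9, 0, 0],   # request slot 1
])
strided_indices = torch.arange(0, req_to_token.shape[1], page_size)  # [0, 2, 4]

page_table = req_to_token[:, strided_indices]   # keep every page_size-th column
bs = 1
req_pool_indices = torch.tensor([1, 0])         # only the first bs entries are live
page_table = page_table[req_pool_indices[:bs]]  # rows of the active requests
print(page_table)  # tensor([[4, 8, 0]])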