Enable local attention during decode (#5479)
This commit is contained in:
@@ -142,6 +142,16 @@ def make_local_attention_virtual_batches(
|
||||
seqlens_k_local: Key sequence lengths for local attention
|
||||
block_table_local: Block table for local attention
|
||||
"""
|
||||
# Adjust attention_chunk_size based on the actual sequence length
|
||||
# to avoid index out of bounds errors
|
||||
max_seq_len = seq_lens_np.max()
|
||||
effective_chunk_size = min(attn_chunk_size, max_seq_len)
|
||||
# Make sure effective_chunk_size is divisible by page_size
|
||||
effective_chunk_size = (effective_chunk_size // page_size) * page_size
|
||||
if effective_chunk_size < page_size:
|
||||
effective_chunk_size = page_size
|
||||
attn_chunk_size = effective_chunk_size
|
||||
|
||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||
actual_batch_size = seq_lens_np.shape[0]
|
||||
|
||||
@@ -344,6 +354,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
else:
|
||||
# Normal Decode
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
@@ -357,6 +369,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
metadata.cache_seqlens_int32 = (
|
||||
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
||||
@@ -405,49 +419,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||
|
||||
# Setup local attention if enabled
|
||||
if (
|
||||
self.attention_chunk_size is not None
|
||||
and forward_batch.forward_mode == ForwardMode.EXTEND
|
||||
):
|
||||
# Convert tensors to numpy for local attention processing
|
||||
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
|
||||
seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
|
||||
|
||||
# Adjust attention_chunk_size based on the actual sequence length
|
||||
# to avoid index out of bounds errors
|
||||
max_seq_len = seq_lens_np.max()
|
||||
effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
|
||||
# Make sure effective_chunk_size is divisible by page_size
|
||||
effective_chunk_size = (
|
||||
effective_chunk_size // self.page_size
|
||||
) * self.page_size
|
||||
if effective_chunk_size < self.page_size:
|
||||
effective_chunk_size = self.page_size
|
||||
|
||||
# Create local attention metadata
|
||||
(
|
||||
seqlens_q_local_np,
|
||||
cu_seqlens_q_local_np,
|
||||
seqlens_k_local_np,
|
||||
block_table_local,
|
||||
) = make_local_attention_virtual_batches(
|
||||
effective_chunk_size,
|
||||
cu_seqlens_q_np,
|
||||
seq_lens_np,
|
||||
metadata.page_table,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
|
||||
device
|
||||
),
|
||||
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
||||
local_block_table=block_table_local,
|
||||
local_max_query_len=seqlens_q_local_np.max(),
|
||||
local_max_seq_len=seqlens_k_local_np.max(),
|
||||
)
|
||||
metadata.local_attn_metadata = local_metadata
|
||||
if forward_batch.forward_mode == ForwardMode.EXTEND:
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
|
||||
# Encoder metadata for cross attention
|
||||
if forward_batch.encoder_lens is not None:
|
||||
@@ -704,6 +677,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
# Use precomputed metadata across all layers
|
||||
metadata = self.forward_metadata
|
||||
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
|
||||
use_local_attention = (
|
||||
self.attention_chunk_size is not None and local_attn_metadata is not None
|
||||
)
|
||||
|
||||
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
||||
@@ -738,33 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
||||
)
|
||||
|
||||
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
if layer.is_cross_attention:
|
||||
page_table = metadata.encoder_page_table
|
||||
cache_seqlens = metadata.encoder_lens_int32
|
||||
cu_seqlens_k = metadata.encoder_cu_seqlens_k
|
||||
window_size = (-1, -1)
|
||||
# Always use non-chunked logic for cross-attention
|
||||
o = flash_attn_with_kvcache(
|
||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=metadata.encoder_page_table,
|
||||
cache_seqlens=metadata.encoder_lens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
|
||||
max_seqlen_q=1,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=False,
|
||||
window_size=(-1, -1),
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
elif use_local_attention:
|
||||
# Use chunked (local) attention batching for self-attention
|
||||
o = flash_attn_with_kvcache(
|
||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=local_attn_metadata.local_block_table,
|
||||
cache_seqlens=local_attn_metadata.local_seqused_k,
|
||||
cu_seqlens_q=local_attn_metadata.local_query_start_loc,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
max_seqlen_q=local_attn_metadata.local_max_query_len,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
window_size=(-1, -1),
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
else:
|
||||
page_table = metadata.page_table
|
||||
cache_seqlens = metadata.cache_seqlens_int32
|
||||
cu_seqlens_k = metadata.cu_seqlens_k
|
||||
|
||||
o = flash_attn_with_kvcache(
|
||||
q=q_reshaped,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=page_table,
|
||||
cache_seqlens=cache_seqlens,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=cu_seqlens_k,
|
||||
max_seqlen_q=1,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
# Default: single-token self-attention
|
||||
o = flash_attn_with_kvcache(
|
||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=metadata.page_table,
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
max_seqlen_q=1,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
else:
|
||||
# Do absorbed multi-latent attention
|
||||
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
@@ -986,6 +990,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
seq_lens = seq_lens[:bs]
|
||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||
req_pool_indices = req_pool_indices[:bs]
|
||||
device = seq_lens.device
|
||||
|
||||
if forward_mode.is_decode_or_idle():
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
|
||||
@@ -1012,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
]
|
||||
|
||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
else:
|
||||
# Normal Decode
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
@@ -1035,6 +1043,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
elif forward_mode.is_target_verify():
|
||||
metadata = self.target_verify_metadata[bs]
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
@@ -1085,6 +1094,42 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
"""Get the fill value for sequence length in CUDA graph."""
|
||||
return 0
|
||||
|
||||
def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
|
||||
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
|
||||
if self.attention_chunk_size is None:
|
||||
metadata.local_attn_metadata = None
|
||||
return
|
||||
|
||||
cu_seqlens_q = metadata.cu_seqlens_q
|
||||
cache_seqlens_int32 = metadata.cache_seqlens_int32
|
||||
page_table = metadata.page_table
|
||||
if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
|
||||
metadata.local_attn_metadata = None
|
||||
return
|
||||
|
||||
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
|
||||
seq_lens_np = cache_seqlens_int32.cpu().numpy()
|
||||
(
|
||||
seqlens_q_local_np,
|
||||
cu_seqlens_q_local_np,
|
||||
seqlens_k_local_np,
|
||||
block_table_local,
|
||||
) = make_local_attention_virtual_batches(
|
||||
self.attention_chunk_size,
|
||||
cu_seqlens_q_np,
|
||||
seq_lens_np,
|
||||
page_table,
|
||||
self.page_size,
|
||||
)
|
||||
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
|
||||
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
||||
local_block_table=block_table_local.to(device),
|
||||
local_max_query_len=int(seqlens_q_local_np.max()),
|
||||
local_max_seq_len=int(seqlens_k_local_np.max()),
|
||||
)
|
||||
metadata.local_attn_metadata = local_metadata
|
||||
|
||||
|
||||
class FlashAttentionMultiStepBackend:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user