Enable local attention during decode (#5479)
This commit is contained in:
@@ -142,6 +142,16 @@ def make_local_attention_virtual_batches(
|
|||||||
seqlens_k_local: Key sequence lengths for local attention
|
seqlens_k_local: Key sequence lengths for local attention
|
||||||
block_table_local: Block table for local attention
|
block_table_local: Block table for local attention
|
||||||
"""
|
"""
|
||||||
|
# Adjust attention_chunk_size based on the actual sequence length
|
||||||
|
# to avoid index out of bounds errors
|
||||||
|
max_seq_len = seq_lens_np.max()
|
||||||
|
effective_chunk_size = min(attn_chunk_size, max_seq_len)
|
||||||
|
# Make sure effective_chunk_size is divisible by page_size
|
||||||
|
effective_chunk_size = (effective_chunk_size // page_size) * page_size
|
||||||
|
if effective_chunk_size < page_size:
|
||||||
|
effective_chunk_size = page_size
|
||||||
|
attn_chunk_size = effective_chunk_size
|
||||||
|
|
||||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||||
actual_batch_size = seq_lens_np.shape[0]
|
actual_batch_size = seq_lens_np.shape[0]
|
||||||
|
|
||||||
@@ -344,6 +354,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self._init_local_attn_metadata(metadata, device)
|
||||||
else:
|
else:
|
||||||
# Normal Decode
|
# Normal Decode
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
@@ -357,6 +369,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self._init_local_attn_metadata(metadata, device)
|
||||||
elif forward_batch.forward_mode.is_target_verify():
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
metadata.cache_seqlens_int32 = (
|
metadata.cache_seqlens_int32 = (
|
||||||
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
||||||
@@ -405,49 +419,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||||
|
|
||||||
# Setup local attention if enabled
|
# Setup local attention if enabled
|
||||||
if (
|
if forward_batch.forward_mode == ForwardMode.EXTEND:
|
||||||
self.attention_chunk_size is not None
|
self._init_local_attn_metadata(metadata, device)
|
||||||
and forward_batch.forward_mode == ForwardMode.EXTEND
|
|
||||||
):
|
|
||||||
# Convert tensors to numpy for local attention processing
|
|
||||||
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
|
|
||||||
seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
|
|
||||||
|
|
||||||
# Adjust attention_chunk_size based on the actual sequence length
|
|
||||||
# to avoid index out of bounds errors
|
|
||||||
max_seq_len = seq_lens_np.max()
|
|
||||||
effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
|
|
||||||
# Make sure effective_chunk_size is divisible by page_size
|
|
||||||
effective_chunk_size = (
|
|
||||||
effective_chunk_size // self.page_size
|
|
||||||
) * self.page_size
|
|
||||||
if effective_chunk_size < self.page_size:
|
|
||||||
effective_chunk_size = self.page_size
|
|
||||||
|
|
||||||
# Create local attention metadata
|
|
||||||
(
|
|
||||||
seqlens_q_local_np,
|
|
||||||
cu_seqlens_q_local_np,
|
|
||||||
seqlens_k_local_np,
|
|
||||||
block_table_local,
|
|
||||||
) = make_local_attention_virtual_batches(
|
|
||||||
effective_chunk_size,
|
|
||||||
cu_seqlens_q_np,
|
|
||||||
seq_lens_np,
|
|
||||||
metadata.page_table,
|
|
||||||
self.page_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
|
||||||
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
|
|
||||||
device
|
|
||||||
),
|
|
||||||
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
|
||||||
local_block_table=block_table_local,
|
|
||||||
local_max_query_len=seqlens_q_local_np.max(),
|
|
||||||
local_max_seq_len=seqlens_k_local_np.max(),
|
|
||||||
)
|
|
||||||
metadata.local_attn_metadata = local_metadata
|
|
||||||
|
|
||||||
# Encoder metadata for cross attention
|
# Encoder metadata for cross attention
|
||||||
if forward_batch.encoder_lens is not None:
|
if forward_batch.encoder_lens is not None:
|
||||||
@@ -704,6 +677,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
# Use precomputed metadata across all layers
|
# Use precomputed metadata across all layers
|
||||||
metadata = self.forward_metadata
|
metadata = self.forward_metadata
|
||||||
|
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
|
||||||
|
use_local_attention = (
|
||||||
|
self.attention_chunk_size is not None and local_attn_metadata is not None
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate window size (can be moved to metadata if layer properties don't change)
|
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||||
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
||||||
@@ -738,33 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
|
||||||
if layer.is_cross_attention:
|
if layer.is_cross_attention:
|
||||||
page_table = metadata.encoder_page_table
|
# Always use non-chunked logic for cross-attention
|
||||||
cache_seqlens = metadata.encoder_lens_int32
|
o = flash_attn_with_kvcache(
|
||||||
cu_seqlens_k = metadata.encoder_cu_seqlens_k
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
window_size = (-1, -1)
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
page_table=metadata.encoder_page_table,
|
||||||
|
cache_seqlens=metadata.encoder_lens_int32,
|
||||||
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=False,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
)
|
||||||
|
elif use_local_attention:
|
||||||
|
# Use chunked (local) attention batching for self-attention
|
||||||
|
o = flash_attn_with_kvcache(
|
||||||
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
page_table=local_attn_metadata.local_block_table,
|
||||||
|
cache_seqlens=local_attn_metadata.local_seqused_k,
|
||||||
|
cu_seqlens_q=local_attn_metadata.local_query_start_loc,
|
||||||
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
|
max_seqlen_q=local_attn_metadata.local_max_query_len,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
page_table = metadata.page_table
|
# Default: single-token self-attention
|
||||||
cache_seqlens = metadata.cache_seqlens_int32
|
o = flash_attn_with_kvcache(
|
||||||
cu_seqlens_k = metadata.cu_seqlens_k
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
k_cache=key_cache,
|
||||||
o = flash_attn_with_kvcache(
|
v_cache=value_cache,
|
||||||
q=q_reshaped,
|
page_table=metadata.page_table,
|
||||||
k_cache=key_cache,
|
cache_seqlens=metadata.cache_seqlens_int32,
|
||||||
v_cache=value_cache,
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
page_table=page_table,
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
cache_seqlens=cache_seqlens,
|
max_seqlen_q=1,
|
||||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
softmax_scale=layer.scaling,
|
||||||
cu_seqlens_k_new=cu_seqlens_k,
|
causal=True,
|
||||||
max_seqlen_q=1,
|
window_size=window_size,
|
||||||
softmax_scale=layer.scaling,
|
softcap=layer.logit_cap,
|
||||||
causal=causal,
|
k_descale=k_descale,
|
||||||
window_size=window_size,
|
v_descale=v_descale,
|
||||||
softcap=layer.logit_cap,
|
)
|
||||||
k_descale=k_descale,
|
|
||||||
v_descale=v_descale,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Do absorbed multi-latent attention
|
# Do absorbed multi-latent attention
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
@@ -986,6 +990,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
seq_lens = seq_lens[:bs]
|
seq_lens = seq_lens[:bs]
|
||||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||||
req_pool_indices = req_pool_indices[:bs]
|
req_pool_indices = req_pool_indices[:bs]
|
||||||
|
device = seq_lens.device
|
||||||
|
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
|
||||||
@@ -1012,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
]
|
]
|
||||||
|
|
||||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||||
|
|
||||||
|
self._init_local_attn_metadata(metadata, device)
|
||||||
else:
|
else:
|
||||||
# Normal Decode
|
# Normal Decode
|
||||||
max_len = seq_lens_cpu.max().item()
|
max_len = seq_lens_cpu.max().item()
|
||||||
@@ -1035,6 +1043,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||||
|
|
||||||
|
self._init_local_attn_metadata(metadata, device)
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
metadata = self.target_verify_metadata[bs]
|
metadata = self.target_verify_metadata[bs]
|
||||||
metadata.cache_seqlens_int32.copy_(
|
metadata.cache_seqlens_int32.copy_(
|
||||||
@@ -1085,6 +1094,42 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
"""Get the fill value for sequence length in CUDA graph."""
|
"""Get the fill value for sequence length in CUDA graph."""
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
    """Build or clear ``metadata.local_attn_metadata`` for chunked attention.

    When ``self.attention_chunk_size`` is ``None``, or when any of
    ``metadata.cu_seqlens_q``, ``metadata.cache_seqlens_int32`` or
    ``metadata.page_table`` is ``None``, the field is reset to ``None``.
    Otherwise the query offsets and key lengths are copied to host memory,
    handed to ``make_local_attention_virtual_batches``, and the result is
    wrapped in a ``FlashAttentionMetadata.LocalAttentionMetadata`` whose
    tensors are moved to ``device``.

    Args:
        metadata: Metadata object that is read from and updated in place.
        device: Target device for the tensors stored in the local metadata.
    """
    # Chunked attention disabled: make sure no stale local metadata is left.
    if self.attention_chunk_size is None:
        metadata.local_attn_metadata = None
        return

    q_start_loc = metadata.cu_seqlens_q
    kv_lens = metadata.cache_seqlens_int32
    pages = metadata.page_table
    # All three inputs are required to derive the virtual batches.
    if any(item is None for item in (q_start_loc, kv_lens, pages)):
        metadata.local_attn_metadata = None
        return

    # The virtual-batch helper works on NumPy arrays, so the tensors are
    # brought to the host first (this is a device-to-host copy).
    (
        q_lens_local,
        q_start_loc_local,
        kv_lens_local,
        pages_local,
    ) = make_local_attention_virtual_batches(
        self.attention_chunk_size,
        q_start_loc.cpu().numpy(),
        kv_lens.cpu().numpy(),
        pages,
        self.page_size,
    )

    # Maxima are converted to plain Python ints rather than NumPy scalars.
    metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
        local_query_start_loc=torch.from_numpy(q_start_loc_local).to(device),
        local_seqused_k=torch.from_numpy(kv_lens_local).to(device),
        local_block_table=pages_local.to(device),
        local_max_query_len=int(q_lens_local.max()),
        local_max_seq_len=int(kv_lens_local.max()),
    )
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionMultiStepBackend:
|
class FlashAttentionMultiStepBackend:
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user