[Bugfix] Fix Llama4 gibberish output with long context and CUDA graph (#6162)
This commit is contained in:
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
# Use precomputed metadata across all layers
|
||||
metadata = self.forward_metadata
|
||||
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
|
||||
use_local_attention = (
|
||||
self.attention_chunk_size is not None and local_attn_metadata is not None
|
||||
use_local_attn = (
|
||||
self.attention_chunk_size is not None
|
||||
and local_attn_metadata is not None
|
||||
and (hasattr(layer, "use_irope") and layer.use_irope)
|
||||
)
|
||||
# We do cascade attention for Draft Decode with topk > 1
|
||||
use_cascade_attn = self.topk > 1
|
||||
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
elif use_local_attention:
|
||||
elif use_local_attn:
|
||||
# Use chunked (local) attention batching for self-attention
|
||||
o = flash_attn_with_kvcache(
|
||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
page_table=local_attn_metadata.local_block_table,
|
||||
cache_seqlens=local_attn_metadata.local_seqused_k,
|
||||
cu_seqlens_q=local_attn_metadata.local_query_start_loc,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
cu_seqlens_k_new=None,
|
||||
max_seqlen_q=local_attn_metadata.local_max_query_len,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||
to avoid memory allocations.
|
||||
"""
|
||||
|
||||
# This is being used by normal decode and draft decode when topk == 1
|
||||
self.decode_cuda_graph_metadata = {
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
}
|
||||
|
||||
# Only allocate local attention buffers if local attention is enabled
|
||||
# This prevents OOM errors when local attention is not being used
|
||||
if self.attention_chunk_size is not None:
|
||||
# Estimate maximum sizes for local attention metadata
|
||||
max_seq_len = self.max_context_len
|
||||
page_size = self.page_size or 1
|
||||
attn_chunk_size = self.attention_chunk_size
|
||||
max_virtual_batches = max_bs * (
|
||||
(max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
||||
)
|
||||
max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
||||
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
|
||||
|
||||
self.decode_cuda_graph_local_attn_metadata = {
|
||||
"local_query_start_loc": torch.zeros(
|
||||
max_virtual_batches + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"local_seqused_k": torch.zeros(
|
||||
max_virtual_batches, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"local_block_table": torch.zeros(
|
||||
max_virtual_batches,
|
||||
max_blocks_per_seq * max_pages_per_block,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
|
||||
# This is used by draft decode's first half of metadata when topk > 1
|
||||
if self.topk > 1:
|
||||
self.draft_decode_metadata_topk_normal = {
|
||||
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
|
||||
if self.attention_chunk_size is not None:
|
||||
metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
|
||||
"local_query_start_loc"
|
||||
],
|
||||
local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
|
||||
"local_seqused_k"
|
||||
],
|
||||
local_block_table=self.decode_cuda_graph_local_attn_metadata[
|
||||
"local_block_table"
|
||||
],
|
||||
local_max_query_len=1,
|
||||
local_max_seq_len=1,
|
||||
)
|
||||
|
||||
elif forward_mode.is_target_verify():
|
||||
if self.topk <= 1:
|
||||
metadata.cache_seqlens_int32 = self.target_verify_metadata[
|
||||
@@ -1572,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
|
||||
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
||||
)
|
||||
# TODO: we need to test this part for llama 4 eagle case
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
# TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
|
||||
else:
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
# Normal Decode
|
||||
@@ -1599,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
self._update_local_attn_metadata_for_replay(metadata, bs)
|
||||
elif forward_mode.is_target_verify():
|
||||
if self.topk <= 1:
|
||||
metadata = self.target_verify_metadata[bs]
|
||||
@@ -1755,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
page_table,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
|
||||
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
||||
@@ -1764,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
metadata.local_attn_metadata = local_metadata
|
||||
|
||||
def _update_local_attn_metadata_for_replay(
|
||||
self, metadata: FlashAttentionMetadata, bs: int
|
||||
):
|
||||
"""Update preallocated local attention metadata in-place before CUDA graph replay."""
|
||||
if self.attention_chunk_size is None:
|
||||
return
|
||||
|
||||
# Access preallocated buffers
|
||||
local_q_buf = self.decode_cuda_graph_local_attn_metadata[
|
||||
"local_query_start_loc"
|
||||
]
|
||||
local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
|
||||
local_block_buf = self.decode_cuda_graph_local_attn_metadata[
|
||||
"local_block_table"
|
||||
]
|
||||
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
|
||||
|
||||
# Create a modified version for local attention that only processes the last token
|
||||
# This mimics the normal decode pattern
|
||||
cu_seqlens_q = torch.arange(
|
||||
bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
|
||||
)
|
||||
seqlens = metadata.cache_seqlens_int32[:bs]
|
||||
# Slice the page_table to match the batch size and actual sequence length
|
||||
# This serves three important purposes:
|
||||
# 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
|
||||
# 2. Limits the sequence length to prevent processing padding tokens or garbage values
|
||||
# 3. Prevents zeros in the block table which can cause garbage output during replay
|
||||
#
|
||||
# Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
|
||||
# beyond the actual sequence length, leading to incorrect attention calculations
|
||||
max_seq_len = int(seqlens.max().item())
|
||||
sliced_page_table = metadata.page_table[:bs, :max_seq_len]
|
||||
|
||||
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
|
||||
seqlens_np = seqlens.cpu().numpy()
|
||||
(
|
||||
seqlens_q_local_np,
|
||||
cu_seqlens_q_local_np,
|
||||
seqlens_k_local_np,
|
||||
block_table_local,
|
||||
) = make_local_attention_virtual_batches(
|
||||
self.attention_chunk_size,
|
||||
cu_seqlens_q_np,
|
||||
seqlens_np,
|
||||
sliced_page_table,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
# Convert back to tensors
|
||||
device = local_q_buf.device
|
||||
cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
|
||||
seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
|
||||
block_table_local = block_table_local.to(device)
|
||||
# Get sizes
|
||||
q_len = cu_seqlens_q_local.shape[0]
|
||||
k_len = seqlens_k_local.shape[0]
|
||||
b0, b1 = block_table_local.shape
|
||||
|
||||
# In-place updates into preallocated tensors and zero out the unused space
|
||||
local_q_buf[:q_len].copy_(cu_seqlens_q_local)
|
||||
local_q_buf[q_len:].fill_(0)
|
||||
local_k_buf[:k_len].copy_(seqlens_k_local)
|
||||
local_k_buf[k_len:].fill_(0)
|
||||
local_block_buf[:b0, :b1].copy_(block_table_local)
|
||||
local_block_buf[b0:, :].fill_(0)
|
||||
local_block_buf[:b0, b1:].fill_(0)
|
||||
|
||||
if metadata.local_attn_metadata is not None:
|
||||
lam = metadata.local_attn_metadata
|
||||
lam.local_max_query_len = int(seqlens_q_local_np.max())
|
||||
lam.local_max_seq_len = int(seqlens_k_local_np.max())
|
||||
|
||||
|
||||
class FlashAttentionMultiStepBackend:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user