[Bugfix] Fix Llama4 gibberish output with long context and CUDA graph (#6162)
This commit is contained in:
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# Use precomputed metadata across all layers
|
# Use precomputed metadata across all layers
|
||||||
metadata = self.forward_metadata
|
metadata = self.forward_metadata
|
||||||
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
|
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
|
||||||
use_local_attention = (
|
use_local_attn = (
|
||||||
self.attention_chunk_size is not None and local_attn_metadata is not None
|
self.attention_chunk_size is not None
|
||||||
|
and local_attn_metadata is not None
|
||||||
|
and (hasattr(layer, "use_irope") and layer.use_irope)
|
||||||
)
|
)
|
||||||
# We do cascade attention for Draft Decode with topk > 1
|
# We do cascade attention for Draft Decode with topk > 1
|
||||||
use_cascade_attn = self.topk > 1
|
use_cascade_attn = self.topk > 1
|
||||||
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
)
|
)
|
||||||
elif use_local_attention:
|
elif use_local_attn:
|
||||||
# Use chunked (local) attention batching for self-attention
|
# Use chunked (local) attention batching for self-attention
|
||||||
o = flash_attn_with_kvcache(
|
o = flash_attn_with_kvcache(
|
||||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
page_table=local_attn_metadata.local_block_table,
|
page_table=local_attn_metadata.local_block_table,
|
||||||
cache_seqlens=local_attn_metadata.local_seqused_k,
|
cache_seqlens=local_attn_metadata.local_seqused_k,
|
||||||
cu_seqlens_q=local_attn_metadata.local_query_start_loc,
|
cu_seqlens_q=local_attn_metadata.local_query_start_loc,
|
||||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
cu_seqlens_k_new=None,
|
||||||
max_seqlen_q=local_attn_metadata.local_max_query_len,
|
max_seqlen_q=local_attn_metadata.local_max_query_len,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
This creates fixed-size tensors that will be reused during CUDA graph replay
|
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||||
to avoid memory allocations.
|
to avoid memory allocations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This is being used by normal decode and draft decode when topk == 1
|
# This is being used by normal decode and draft decode when topk == 1
|
||||||
self.decode_cuda_graph_metadata = {
|
self.decode_cuda_graph_metadata = {
|
||||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||||
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Only allocate local attention buffers if local attention is enabled
|
||||||
|
# This prevents OOM errors when local attention is not being used
|
||||||
|
if self.attention_chunk_size is not None:
|
||||||
|
# Estimate maximum sizes for local attention metadata
|
||||||
|
max_seq_len = self.max_context_len
|
||||||
|
page_size = self.page_size or 1
|
||||||
|
attn_chunk_size = self.attention_chunk_size
|
||||||
|
max_virtual_batches = max_bs * (
|
||||||
|
(max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
||||||
|
)
|
||||||
|
max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
||||||
|
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
|
||||||
|
|
||||||
|
self.decode_cuda_graph_local_attn_metadata = {
|
||||||
|
"local_query_start_loc": torch.zeros(
|
||||||
|
max_virtual_batches + 1, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"local_seqused_k": torch.zeros(
|
||||||
|
max_virtual_batches, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"local_block_table": torch.zeros(
|
||||||
|
max_virtual_batches,
|
||||||
|
max_blocks_per_seq * max_pages_per_block,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
# This is used by draft decode's first half of metadata when topk > 1
|
# This is used by draft decode's first half of metadata when topk > 1
|
||||||
if self.topk > 1:
|
if self.topk > 1:
|
||||||
self.draft_decode_metadata_topk_normal = {
|
self.draft_decode_metadata_topk_normal = {
|
||||||
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
self.decode_cuda_graph_metadata[bs] = metadata
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
|
|
||||||
|
if self.attention_chunk_size is not None:
|
||||||
|
metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||||
|
local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
|
||||||
|
"local_query_start_loc"
|
||||||
|
],
|
||||||
|
local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
|
||||||
|
"local_seqused_k"
|
||||||
|
],
|
||||||
|
local_block_table=self.decode_cuda_graph_local_attn_metadata[
|
||||||
|
"local_block_table"
|
||||||
|
],
|
||||||
|
local_max_query_len=1,
|
||||||
|
local_max_seq_len=1,
|
||||||
|
)
|
||||||
|
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
if self.topk <= 1:
|
if self.topk <= 1:
|
||||||
metadata.cache_seqlens_int32 = self.target_verify_metadata[
|
metadata.cache_seqlens_int32 = self.target_verify_metadata[
|
||||||
@@ -1572,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
|
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
|
||||||
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
||||||
)
|
)
|
||||||
# TODO: we need to test this part for llama 4 eagle case
|
# TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
|
||||||
self._init_local_attn_metadata(metadata, device)
|
|
||||||
else:
|
else:
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
# Normal Decode
|
# Normal Decode
|
||||||
@@ -1599,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||||
|
|
||||||
self._init_local_attn_metadata(metadata, device)
|
self._update_local_attn_metadata_for_replay(metadata, bs)
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
if self.topk <= 1:
|
if self.topk <= 1:
|
||||||
metadata = self.target_verify_metadata[bs]
|
metadata = self.target_verify_metadata[bs]
|
||||||
@@ -1755,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
page_table,
|
page_table,
|
||||||
self.page_size,
|
self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||||
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
|
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
|
||||||
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
||||||
@@ -1764,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
metadata.local_attn_metadata = local_metadata
|
metadata.local_attn_metadata = local_metadata
|
||||||
|
|
||||||
|
def _update_local_attn_metadata_for_replay(
|
||||||
|
self, metadata: FlashAttentionMetadata, bs: int
|
||||||
|
):
|
||||||
|
"""Update preallocated local attention metadata in-place before CUDA graph replay."""
|
||||||
|
if self.attention_chunk_size is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Access preallocated buffers
|
||||||
|
local_q_buf = self.decode_cuda_graph_local_attn_metadata[
|
||||||
|
"local_query_start_loc"
|
||||||
|
]
|
||||||
|
local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
|
||||||
|
local_block_buf = self.decode_cuda_graph_local_attn_metadata[
|
||||||
|
"local_block_table"
|
||||||
|
]
|
||||||
|
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
|
||||||
|
|
||||||
|
# Create a modified version for local attention that only processes the last token
|
||||||
|
# This mimics the normal decode pattern
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
|
||||||
|
)
|
||||||
|
seqlens = metadata.cache_seqlens_int32[:bs]
|
||||||
|
# Slice the page_table to match the batch size and actual sequence length
|
||||||
|
# This serves three important purposes:
|
||||||
|
# 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
|
||||||
|
# 2. Limits the sequence length to prevent processing padding tokens or garbage values
|
||||||
|
# 3. Prevents zeros in the block table which can cause garbage output during replay
|
||||||
|
#
|
||||||
|
# Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
|
||||||
|
# beyond the actual sequence length, leading to incorrect attention calculations
|
||||||
|
max_seq_len = int(seqlens.max().item())
|
||||||
|
sliced_page_table = metadata.page_table[:bs, :max_seq_len]
|
||||||
|
|
||||||
|
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
|
||||||
|
seqlens_np = seqlens.cpu().numpy()
|
||||||
|
(
|
||||||
|
seqlens_q_local_np,
|
||||||
|
cu_seqlens_q_local_np,
|
||||||
|
seqlens_k_local_np,
|
||||||
|
block_table_local,
|
||||||
|
) = make_local_attention_virtual_batches(
|
||||||
|
self.attention_chunk_size,
|
||||||
|
cu_seqlens_q_np,
|
||||||
|
seqlens_np,
|
||||||
|
sliced_page_table,
|
||||||
|
self.page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert back to tensors
|
||||||
|
device = local_q_buf.device
|
||||||
|
cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
|
||||||
|
seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
|
||||||
|
block_table_local = block_table_local.to(device)
|
||||||
|
# Get sizes
|
||||||
|
q_len = cu_seqlens_q_local.shape[0]
|
||||||
|
k_len = seqlens_k_local.shape[0]
|
||||||
|
b0, b1 = block_table_local.shape
|
||||||
|
|
||||||
|
# In-place updates into preallocated tensors and zero out the unused space
|
||||||
|
local_q_buf[:q_len].copy_(cu_seqlens_q_local)
|
||||||
|
local_q_buf[q_len:].fill_(0)
|
||||||
|
local_k_buf[:k_len].copy_(seqlens_k_local)
|
||||||
|
local_k_buf[k_len:].fill_(0)
|
||||||
|
local_block_buf[:b0, :b1].copy_(block_table_local)
|
||||||
|
local_block_buf[b0:, :].fill_(0)
|
||||||
|
local_block_buf[:b0, b1:].fill_(0)
|
||||||
|
|
||||||
|
if metadata.local_attn_metadata is not None:
|
||||||
|
lam = metadata.local_attn_metadata
|
||||||
|
lam.local_max_query_len = int(seqlens_q_local_np.max())
|
||||||
|
lam.local_max_seq_len = int(seqlens_k_local_np.max())
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionMultiStepBackend:
|
class FlashAttentionMultiStepBackend:
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user