longcontext chunk make attention crash, fix it (#117)

Co-authored-by: root <root@rdtest-node1150.bcc-zwlt.baidu.com>
This commit is contained in:
baoqian426
2026-01-17 18:38:23 +08:00
committed by GitHub
parent 71a5a04e0a
commit 2512259944
3 changed files with 37 additions and 6 deletions

View File

@@ -207,7 +207,6 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.distributed import get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
@@ -322,6 +321,7 @@ class MLACommonPrefillMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
cu_seq_lens_cpu: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
@@ -795,6 +795,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu \
.to(device, non_blocking=True),
cu_seq_lens_cpu=cu_seq_lens_cpu,
starts=cp_chunk_starts.to(device, non_blocking=True),
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -812,6 +813,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu \
.to(device, non_blocking=True),
cu_seq_lens_cpu=cu_seq_lens_cpu,
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -1215,7 +1217,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# )
attn_out = torch.empty_like(q)
ds_alpha = 1.8738542070926265
tp_q_head_num=128
tp_q_head_num=q.size(1)
softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
softmax_lse.fill_(float('-inf'))
xtorch_ops.attention(
@@ -1247,7 +1249,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if return_softmax_lse:
return attn_out, lse
return attn_out, softmax_lse
return attn_out
def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q,
@@ -1310,7 +1312,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k=k,
v=v,
context_seq_lod_xpu=prefill.query_start_loc,
context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens[chunk_idx],
context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens_cpu[chunk_idx],
softmax_scale=self.scale,
causal=False,
return_softmax_lse=True,
@@ -1548,7 +1550,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
self.gather_and_maybe_dequant_cache_py_optimized(
torch.ops.xspeedgate_ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
@@ -1743,12 +1745,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
output = torch.empty_like(suffix_output)
output_lse = torch.empty_like(output)
merge_attn_states(
output=output,
prefix_output=context_output,
prefix_lse=context_lse,
suffix_output=suffix_output,
suffix_lse=suffix_lse,
output_lse=output_lse,
)
# unpad if necessary