Fix attention crash caused by long-context chunked prefill (#117)

Co-authored-by: root <root@rdtest-node1150.bcc-zwlt.baidu.com>

Author: baoqian426
Date: 2026-01-17 18:38:23 +08:00
Committed by: GitHub
Parent: 71a5a04e0a
Commit: 2512259944
3 changed files with 37 additions and 6 deletions


@@ -17,7 +17,8 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
         "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
         "vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
         "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
-        "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler"
+        "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler",
+        "vllm.attention.ops.merge_attn_states": "vllm_kunlun.ops.attention.merge_attn_states"
     }
     if module_name in module_mappings:
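For context, vllm_kunlun swaps in its own implementations by intercepting imports of selected vLLM modules; the new mapping entry routes vllm.attention.ops.merge_attn_states to the Kunlun version added below. A simplified sketch of that interception pattern, assuming a builtins.__import__ hook (the hook body here is illustrative and omits the fromlist/level handling a real hook needs):

import builtins
import importlib

_original_import = builtins.__import__  # keep a handle on the real import


def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
    # Illustrative subset of the mapping; the real plugin defines many more entries.
    module_mappings = {
        "vllm.attention.ops.merge_attn_states":
            "vllm_kunlun.ops.attention.merge_attn_states",
    }
    if module_name in module_mappings:
        # Load the Kunlun replacement instead of the upstream module.
        return importlib.import_module(module_mappings[module_name])
    return _original_import(module_name, globals, locals, fromlist, level)


builtins.__import__ = _custom_import  # install the hook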


@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+import xtorch_ops
+
+from vllm.platforms import current_platform
+
+
+def merge_attn_states(
+    output: torch.Tensor,
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output_lse: Optional[torch.Tensor] = None,
+) -> None:
+    return xtorch_ops.attention_merge_stage(
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        output,
+        output_lse
+    )
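For reference, the merge this wrapper delegates to is the standard log-sum-exp combination of two partial attention results, the same contract as upstream vLLM's merge_attn_states. A minimal PyTorch sketch of that contract (reference only, not the Kunlun kernel; it assumes vLLM-style shapes of [tokens, heads, head_size] for the outputs and [heads, tokens] for the lse tensors, which may not match what xtorch_ops.attention_merge_stage expects):

import torch


def merge_attn_states_ref(output, prefix_output, prefix_lse,
                          suffix_output, suffix_lse, output_lse=None):
    # Combined log-sum-exp of the two partial softmax denominators: [heads, tokens].
    # Production kernels also guard the case where both lse values are -inf.
    merged_lse = torch.logaddexp(prefix_lse, suffix_lse)
    # Per-token, per-head weights for each partial output: [tokens, heads, 1].
    p_w = torch.exp(prefix_lse - merged_lse).transpose(0, 1).unsqueeze(-1)
    s_w = torch.exp(suffix_lse - merged_lse).transpose(0, 1).unsqueeze(-1)
    output.copy_(prefix_output * p_w + suffix_output * s_w)
    if output_lse is not None:
        output_lse.copy_(merged_lse)

Note that the xtorch_ops op takes the output buffers last, whereas the vLLM-facing wrapper keeps output first so the upstream call sites do not change.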


@@ -207,7 +207,6 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
-from vllm.distributed import get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -322,6 +321,7 @@ class MLACommonPrefillMetadata:
         # New for MLA (compared to FlashAttention)
         # For handling chunked prefill
         cu_seq_lens: torch.Tensor
+        cu_seq_lens_cpu: torch.Tensor
         starts: torch.Tensor
         seq_tot: list[int]
         max_seq_lens: list[int]
@@ -795,6 +795,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 chunked_context_metadata_cls(
                     cu_seq_lens=cu_seq_lens_cpu \
                         .to(device, non_blocking=True),
+                    cu_seq_lens_cpu=cu_seq_lens_cpu,
                     starts=cp_chunk_starts.to(device, non_blocking=True),
                     seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -812,6 +813,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 chunked_context_metadata_cls(
                     cu_seq_lens=cu_seq_lens_cpu \
                         .to(device, non_blocking=True),
+                    cu_seq_lens_cpu=cu_seq_lens_cpu,
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -1215,7 +1217,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # )
         attn_out = torch.empty_like(q)
         ds_alpha = 1.8738542070926265
-        tp_q_head_num=128
+        tp_q_head_num=q.size(1)
         softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
         softmax_lse.fill_(float('-inf'))
         xtorch_ops.attention(
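The hunk above replaces a hardcoded head count with one derived from q. A small self-contained sketch of why that matters, assuming the layout implied by the softmax_lse allocation in the diff (q as [num_tokens, num_heads_per_rank, head_dim]; the concrete sizes below are illustrative):

import torch

# With tensor parallelism each rank holds only a slice of the query heads,
# so the lse buffer must follow q's actual head dimension, not a fixed 128.
q = torch.randn(16, 8, 192)  # e.g. 8 query heads on this TP rank
tp_q_head_num = q.size(1)
softmax_lse = torch.full((tp_q_head_num, q.size(0)), float("-inf"),
                         dtype=torch.float32, device=q.device)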
@@ -1247,7 +1249,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
-            return attn_out, lse
+            return attn_out, softmax_lse
         return attn_out

     def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q,
@@ -1310,7 +1312,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             k=k,
             v=v,
             context_seq_lod_xpu=prefill.query_start_loc,
-            context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens[chunk_idx],
+            context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens_cpu[chunk_idx],
             softmax_scale=self.scale,
             causal=False,
             return_softmax_lse=True,
@@ -1548,7 +1550,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            self.gather_and_maybe_dequant_cache_py_optimized(
+            torch.ops.xspeedgate_ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
@@ -1743,12 +1745,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             q, kv_c_and_k_pe_cache, attn_metadata, k_scale)

         output = torch.empty_like(suffix_output)
+        output_lse = torch.empty_like(output)
         merge_attn_states(
             output=output,
             prefix_output=context_output,
             prefix_lse=context_lse,
             suffix_output=suffix_output,
             suffix_lse=suffix_lse,
+            output_lse=output_lse,
         )

         # unpad if necessary
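Taken together: the chunked-prefill metadata previously exposed only the device copy of cu_seq_lens, so the slice handed to the kernel's context_seq_lod_cpu argument was a device tensor, which appears to be what crashed long-context runs. The builder now keeps the host tensor alongside the non-blocking device copy, and the final merge allocates an output_lse buffer for the Kunlun merge op. A toy sketch of that host/device split (shapes and sizes are illustrative; the real builder code is longer):

import torch

# 2 context chunks over 3 requests, built on the host first.
chunk_seq_lens = torch.tensor([[512, 256, 0], [128, 0, 0]], dtype=torch.int32)
cu_seq_lens_cpu = torch.zeros(2, 4, dtype=torch.int32)
cu_seq_lens_cpu[:, 1:] = torch.cumsum(chunk_seq_lens, dim=1, dtype=torch.int32)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Device copy for GPU-side indexing ...
cu_seq_lens = cu_seq_lens_cpu.to(device, non_blocking=True)
# ... while the retained host tensor can be sliced per chunk and passed as a
# host-side lod (context_seq_lod_cpu in the diff) without a device-to-host copy.
lod_chunk0 = cu_seq_lens_cpu[0]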