longcontext chunk make attention crash, fix it (#117)
Co-authored-by: root <root@rdtest-node1150.bcc-zwlt.baidu.com>
This commit is contained in:
@@ -17,7 +17,8 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
|
||||
"vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
|
||||
"vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
|
||||
"vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
|
||||
"vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler"
|
||||
"vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler",
|
||||
"vllm.attention.ops.merge_attn_states": "vllm_kunlun.ops.attention.merge_attn_states"
|
||||
}
|
||||
|
||||
if module_name in module_mappings:
|
||||
|
||||
26
vllm_kunlun/ops/attention/merge_attn_states.py
Normal file
26
vllm_kunlun/ops/attention/merge_attn_states.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import xtorch_ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
|
||||
return xtorch_ops.attention_merge_stage(
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output,
|
||||
output_lse
|
||||
)
|
||||
@@ -207,7 +207,6 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase,
|
||||
@@ -322,6 +321,7 @@ class MLACommonPrefillMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
cu_seq_lens_cpu: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
@@ -795,6 +795,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu \
|
||||
.to(device, non_blocking=True),
|
||||
cu_seq_lens_cpu=cu_seq_lens_cpu,
|
||||
starts=cp_chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
@@ -812,6 +813,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu \
|
||||
.to(device, non_blocking=True),
|
||||
cu_seq_lens_cpu=cu_seq_lens_cpu,
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
@@ -1215,7 +1217,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
# )
|
||||
attn_out = torch.empty_like(q)
|
||||
ds_alpha = 1.8738542070926265
|
||||
tp_q_head_num=128
|
||||
tp_q_head_num=q.size(1)
|
||||
softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
|
||||
softmax_lse.fill_(float('-inf'))
|
||||
xtorch_ops.attention(
|
||||
@@ -1247,7 +1249,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
# Remain consistent with old `flash_attn_varlen_func` where there
|
||||
# is only one output tensor if `return_softmax_lse` is False.
|
||||
if return_softmax_lse:
|
||||
return attn_out, lse
|
||||
return attn_out, softmax_lse
|
||||
return attn_out
|
||||
|
||||
def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q,
|
||||
@@ -1310,7 +1312,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
k=k,
|
||||
v=v,
|
||||
context_seq_lod_xpu=prefill.query_start_loc,
|
||||
context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens[chunk_idx],
|
||||
context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens_cpu[chunk_idx],
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
return_softmax_lse=True,
|
||||
@@ -1548,7 +1550,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
self.gather_and_maybe_dequant_cache_py_optimized(
|
||||
torch.ops.xspeedgate_ops.gather_and_maybe_dequant_cache(
|
||||
src_cache=kv_c_and_k_pe_cache,
|
||||
dst=workspace,
|
||||
block_table=prefill_metadata.block_table,
|
||||
@@ -1743,12 +1745,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
|
||||
|
||||
output = torch.empty_like(suffix_output)
|
||||
output_lse = torch.empty_like(output)
|
||||
merge_attn_states(
|
||||
output=output,
|
||||
prefix_output=context_output,
|
||||
prefix_lse=context_lse,
|
||||
suffix_output=suffix_output,
|
||||
suffix_lse=suffix_lse,
|
||||
output_lse=output_lse,
|
||||
)
|
||||
|
||||
# unpad if necessary
|
||||
|
||||
Reference in New Issue
Block a user