Fix attention crash caused by long-context chunked prefill (#117)
Co-authored-by: root <root@rdtest-node1150.bcc-zwlt.baidu.com>
@@ -17,7 +17,8 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
         "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
         "vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
         "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
-        "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler"
+        "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler",
+        "vllm.attention.ops.merge_attn_states": "vllm_kunlun.ops.attention.merge_attn_states"
     }
 
     if module_name in module_mappings:
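Note: the table extended above is consumed by the `_custom_import` hook named in the hunk header. As a rough, hypothetical sketch only (the repository's actual hook wiring is not shown in this diff), a `builtins.__import__` override with that signature usually looks like this; `_original_import` is an assumed name for the saved built-in:

import builtins

_original_import = builtins.__import__  # assumed name; the built-in saved before patching

module_mappings = {
    # same shape as the mapping table in the hunk above
    "vllm.attention.ops.merge_attn_states": "vllm_kunlun.ops.attention.merge_attn_states",
}

def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
    # Redirect vLLM module paths to their Kunlun replacements before importing.
    if module_name in module_mappings:
        module_name = module_mappings[module_name]
    return _original_import(module_name, globals, locals, fromlist, level)

builtins.__import__ = _custom_import

Because the lookup happens before the real import runs, adding the `merge_attn_states` entry is enough to route `from vllm.attention.ops.merge_attn_states import merge_attn_states` to the new Kunlun module added below.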
vllm_kunlun/ops/attention/merge_attn_states.py (new normal file, 26 lines)
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+import xtorch_ops
+from vllm.platforms import current_platform
+
+
+def merge_attn_states(
+    output: torch.Tensor,
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output_lse: Optional[torch.Tensor] = None,
+) -> None:
+
+    return xtorch_ops.attention_merge_stage(
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        output,
+        output_lse
+    )
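The wrapper keeps the keyword interface of vLLM's `merge_attn_states` and forwards everything to the Kunlun kernel `xtorch_ops.attention_merge_stage`. A call sketch follows; the tensor names and sizes are placeholders, the LSE layout `(num_heads, num_tokens)` is inferred from the `softmax_lse` allocation later in this diff, and actually executing it requires the Kunlun build of `xtorch_ops`:

import torch
from vllm_kunlun.ops.attention.merge_attn_states import merge_attn_states

num_tokens, num_heads, head_dim = 4, 128, 512            # placeholder sizes only
prefix_out = torch.randn(num_tokens, num_heads, head_dim)  # attention over the cached context chunk
suffix_out = torch.randn(num_tokens, num_heads, head_dim)  # attention over the new tokens
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32)
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32)

output = torch.empty_like(suffix_out)      # merged result, written in place
output_lse = torch.empty_like(output)      # mirrors the allocation added later in this diff

merge_attn_states(
    output=output,
    prefix_output=prefix_out,
    prefix_lse=prefix_lse,
    suffix_output=suffix_out,
    suffix_lse=suffix_lse,
    output_lse=output_lse,
)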
@@ -207,7 +207,6 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
-from vllm.distributed import get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -322,6 +321,7 @@ class MLACommonPrefillMetadata:
         # New for MLA (compared to FlashAttention)
         # For handling chunked prefill
         cu_seq_lens: torch.Tensor
+        cu_seq_lens_cpu: torch.Tensor
         starts: torch.Tensor
         seq_tot: list[int]
         max_seq_lens: list[int]
@@ -795,6 +795,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 chunked_context_metadata_cls(
                     cu_seq_lens=cu_seq_lens_cpu \
                         .to(device, non_blocking=True),
+                    cu_seq_lens_cpu=cu_seq_lens_cpu,
                     starts=cp_chunk_starts.to(device, non_blocking=True),
                     seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -812,6 +813,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 chunked_context_metadata_cls(
                     cu_seq_lens=cu_seq_lens_cpu \
                         .to(device, non_blocking=True),
+                    cu_seq_lens_cpu=cu_seq_lens_cpu,
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -1215,7 +1217,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # )
         attn_out = torch.empty_like(q)
         ds_alpha = 1.8738542070926265
-        tp_q_head_num=128
+        tp_q_head_num=q.size(1)
         softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
         softmax_lse.fill_(float('-inf'))
         xtorch_ops.attention(
@@ -1247,7 +1249,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
-            return attn_out, lse
+            return attn_out, softmax_lse
         return attn_out
 
     def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q,
@@ -1310,7 +1312,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             k=k,
             v=v,
             context_seq_lod_xpu=prefill.query_start_loc,
-            context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens[chunk_idx],
+            context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens_cpu[chunk_idx],
             softmax_scale=self.scale,
             causal=False,
             return_softmax_lse=True,
@@ -1548,7 +1550,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
 
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            self.gather_and_maybe_dequant_cache_py_optimized(
+            torch.ops.xspeedgate_ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
@@ -1743,12 +1745,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
 
         output = torch.empty_like(suffix_output)
+        output_lse = torch.empty_like(output)
         merge_attn_states(
             output=output,
             prefix_output=context_output,
             prefix_lse=context_lse,
             suffix_output=suffix_output,
             suffix_lse=suffix_lse,
+            output_lse=output_lse,
         )
 
         # unpad if necessary
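For reference, the prefix/suffix combination that this commit now routes an explicit `output_lse` through is the standard log-sum-exp merge of two partial attention results computed over disjoint key ranges. A plain-PyTorch sketch of that math (a mental model and CPU sanity check, not the kernel this commit ships; shapes follow the `(num_heads, num_tokens)` LSE layout used above):

import torch

def merge_attn_states_ref(prefix_output: torch.Tensor,   # (tokens, heads, head_dim)
                          prefix_lse: torch.Tensor,      # (heads, tokens)
                          suffix_output: torch.Tensor,   # (tokens, heads, head_dim)
                          suffix_lse: torch.Tensor):     # (heads, tokens)
    # Log-sum-exp merge of two partial softmax-attention results.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p = torch.exp(prefix_lse - max_lse)                  # rescaled prefix weight
    s = torch.exp(suffix_lse - max_lse)                  # rescaled suffix weight
    output_lse = max_lse + torch.log(p + s)              # merged log-sum-exp
    w_p = (p / (p + s)).transpose(0, 1).unsqueeze(-1)    # -> (tokens, heads, 1)
    w_s = (s / (p + s)).transpose(0, 1).unsqueeze(-1)
    output = w_p * prefix_output + w_s * suffix_output   # merged attention output
    return output, output_lse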