diff --git a/vllm_kunlun/__init__.py b/vllm_kunlun/__init__.py
index 4a22dff..97b9d7c 100644
--- a/vllm_kunlun/__init__.py
+++ b/vllm_kunlun/__init__.py
@@ -17,7 +17,8 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
         "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
         "vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
         "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
-        "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler"
+        "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler",
+        "vllm.attention.ops.merge_attn_states": "vllm_kunlun.ops.attention.merge_attn_states"
     }
 
     if module_name in module_mappings:
diff --git a/vllm_kunlun/ops/attention/merge_attn_states.py b/vllm_kunlun/ops/attention/merge_attn_states.py
new file mode 100644
index 0000000..aaab8ad
--- /dev/null
+++ b/vllm_kunlun/ops/attention/merge_attn_states.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+import xtorch_ops
+
+
+def merge_attn_states(
+    output: torch.Tensor,
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output_lse: Optional[torch.Tensor] = None,
+) -> None:
+    """Merge two partial attention results into ``output`` in place.
+
+    Thin wrapper over the Kunlun XPU kernel
+    ``xtorch_ops.attention_merge_stage``: combines ``(prefix_output,
+    prefix_lse)`` with ``(suffix_output, suffix_lse)``, writing the merged
+    attention output into ``output`` and, when ``output_lse`` is provided,
+    the merged log-sum-exp into ``output_lse``.
+    """
+    return xtorch_ops.attention_merge_stage(
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        output,
+        output_lse,
+    )
diff --git a/vllm_kunlun/v1/attention/backends/mla/common.py b/vllm_kunlun/v1/attention/backends/mla/common.py
index a4cd521..60c901d 100644
--- a/vllm_kunlun/v1/attention/backends/mla/common.py
+++ b/vllm_kunlun/v1/attention/backends/mla/common.py
@@ -207,7 +207,6 @@
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
-from vllm.distributed import get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase,
@@ -322,6 +321,7 @@ class MLACommonPrefillMetadata:
     # New for MLA (compared to FlashAttention)
     # For handling chunked prefill
     cu_seq_lens: torch.Tensor
+    cu_seq_lens_cpu: torch.Tensor
     starts: torch.Tensor
     seq_tot: list[int]
     max_seq_lens: list[int]
@@ -795,6 +795,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 chunked_context_metadata_cls(
                     cu_seq_lens=cu_seq_lens_cpu \
                         .to(device, non_blocking=True),
+                    cu_seq_lens_cpu=cu_seq_lens_cpu,
                     starts=cp_chunk_starts.to(device, non_blocking=True),
                     seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -812,6 +813,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 chunked_context_metadata_cls(
                     cu_seq_lens=cu_seq_lens_cpu \
                         .to(device, non_blocking=True),
+                    cu_seq_lens_cpu=cu_seq_lens_cpu,
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -1215,7 +1217,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # )
         attn_out = torch.empty_like(q)
         ds_alpha = 1.8738542070926265
-        tp_q_head_num=128
+        tp_q_head_num=q.size(1)
         softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
         softmax_lse.fill_(float('-inf'))
         xtorch_ops.attention(
@@ -1247,7 +1249,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
-            return attn_out, lse
+            return attn_out, softmax_lse
         return attn_out
 
     def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q,
@@ -1310,7 +1312,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 k=k,
                 v=v,
                 context_seq_lod_xpu=prefill.query_start_loc,
-                context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens[chunk_idx],
+                context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens_cpu[chunk_idx],
                 softmax_scale=self.scale,
                 causal=False,
                 return_softmax_lse=True,
@@ -1548,7 +1550,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
 
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            self.gather_and_maybe_dequant_cache_py_optimized(
+            torch.ops.xspeedgate_ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
@@ -1743,12 +1745,17 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
 
         output = torch.empty_like(suffix_output)
+        # NOTE(review): LSE buffers in this file are (num_heads, num_tokens)
+        # float32 (see softmax_lse above), not output-shaped, so size the
+        # merged-LSE scratch off suffix_lse rather than off output.
+        output_lse = torch.empty_like(suffix_lse)
         merge_attn_states(
             output=output,
             prefix_output=context_output,
             prefix_lse=context_lse,
             suffix_output=suffix_output,
             suffix_lse=suffix_lse,
+            output_lse=output_lse,
         )
 
         # unpad if necessary