diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py
index 66961a52..95ac7fcb 100644
--- a/vllm_ascend/attention/attention_cp.py
+++ b/vllm_ascend/attention/attention_cp.py
@@ -235,8 +235,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
-            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
-            query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,
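Note: `query_lens` and `query_start_loc_list` can be dropped from the metadata because both are derivable from `query_start_loc` (the `xlite.py` hunk at the end of this diff switches to exactly this derivation). A minimal sketch with made-up numbers, assuming `query_start_loc` holds the usual cumulative per-request offsets:

```python
import torch

# query_start_loc holds cumulative token offsets per request.
query_start_loc = torch.tensor([0, 3, 7, 8])  # 3 requests: 3, 4, 1 tokens

# Per-request query lengths are the first differences of the offsets...
query_lens = query_start_loc[1:] - query_start_loc[:-1]
assert query_lens.tolist() == [3, 4, 1]

# ...and the old query_start_loc_list was simply the cumulative tail,
# which actual_seq_lengths_q already stores.
assert query_start_loc[1:].tolist() == [3, 7, 8]
```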
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 2abc18f9..f4107aa2 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -34,7 +34,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills,
+                                         enable_cp, split_decodes_and_prefills,
                                          using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
@@ -52,9 +52,7 @@ class AscendAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if (prefill_config.prefill_context_parallel_size > 1
-                or prefill_config.decode_context_parallel_size > 1):
+        if enable_cp():
             from vllm_ascend.attention.attention_cp import \
                 AscendAttentionCPImpl
             return AscendAttentionCPImpl
@@ -62,9 +60,7 @@ class AscendAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if (prefill_config.prefill_context_parallel_size > 1
-                or prefill_config.decode_context_parallel_size > 1):
+        if enable_cp():
             from vllm_ascend.attention.attention_cp import \
                 AscendAttentionCPMetadataBuilder
             return AscendAttentionCPMetadataBuilder
@@ -191,10 +187,8 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
     seq_lens_list: List[int] = None  # type: ignore
 
     actual_seq_lengths_q: List[int] = None  # type: ignore
-    query_start_loc_list: List[int] = None  # type: ignore
     query_start_loc: torch.Tensor = None
-    query_lens: torch.Tensor = None
 
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
@@ -214,9 +208,9 @@ class AscendMetadata:
     # dcp
     decode_meta: Optional[AscendMetadataForDecode] = None
 
-    # Whether is the pooling model with causal attention,
-    # used to guide the attention computation for pooling models.
-    is_causal_pooling: Optional[bool] = None
+    causal: bool = True  # False for bidirectional (pooling) attention
+    # runner_type in model_config (e.g. "pooling").
+    model_runner_type: str = ""
 
 
 class AscendAttentionMetadataBuilder:
@@ -276,11 +270,8 @@ class AscendAttentionMetadataBuilder:
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
-        assert num_decodes + num_prefills == num_reqs
-        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
 
         block_table = common_attn_metadata.block_table_tensor
-        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
 
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
@@ -297,10 +288,6 @@ class AscendAttentionMetadataBuilder:
         # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
         query_start_loc = query_start_loc_cpu.pin_memory().to(
             self.device, non_blocking=True)
-        is_causal_pooling = None
-        if self.model_config.runner_type == "pooling":
-            is_causal_pooling = common_attn_metadata.causal if hasattr(
-                common_attn_metadata, 'causal') else True
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -308,8 +295,6 @@ class AscendAttentionMetadataBuilder:
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
-            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
-            query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,
@@ -319,7 +304,8 @@ class AscendAttentionMetadataBuilder:
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
-            is_causal_pooling=is_causal_pooling)
+            causal=common_attn_metadata.causal,
+            model_runner_type=self.model_config.runner_type)
         return attn_metadata
 
     def build_for_graph_capture(
@@ -384,9 +370,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
             key, value, block_size, block_table, actual_seq_lengths_kv \
                 = self._get_fia_params(key, value, attn_metadata)
 
-            num_tokens = attn_metadata.query_start_loc_list[-1]
+            num_tokens = attn_metadata.actual_seq_lengths_q[-1]
             graph_params = get_graph_params()
-            query_start_loc = attn_metadata.query_start_loc_list
+            actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
 
             # Prepare tensors for attention output
             # TODO: Refactor this to step-level instead of layer-level
@@ -402,7 +388,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     block_table=block_table,
                     input_layout="TND",
                     block_size=block_size,
-                    actual_seq_lengths=query_start_loc,
+                    actual_seq_lengths=actual_seq_lengths_q,
                     actual_seq_lengths_kv=actual_seq_lengths_kv,
                     num_key_value_heads=self.num_kv_heads,
                     num_heads=self.num_heads,
@@ -422,7 +408,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     (weak_ref_tensors(query), weak_ref_tensors(key),
                      weak_ref_tensors(value), weak_ref_tensors(block_table),
                      weak_ref_tensors(attn_metadata.attn_mask), block_size,
-                     actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
+                     actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
                      self.num_heads, self.scale, weak_ref_tensors(output),
                      weak_ref_tensors(softmax_lse)))
 
@@ -435,7 +421,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 block_table=block_table,
                 input_layout="TND",
                 block_size=block_size,
-                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths=actual_seq_lengths_q,
                 actual_seq_lengths_kv=actual_seq_lengths_kv,
                 num_key_value_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
@@ -518,10 +504,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
-            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
+            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
         elif attn_metadata.attn_state == \
                 AscendAttentionState.PrefillCacheHit:
-            batch_size = attn_metadata.query_lens.shape[0]
+            batch_size = attn_metadata.seq_lens.shape[0]
             block_table = attn_metadata.block_tables[:batch_size, :]
             num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
             key = self.key_cache.view(  # type: ignore
@@ -644,9 +630,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                    attn_metadata: AscendMetadata,
                                    _: torch.Tensor) -> torch.Tensor:
         assert attn_metadata is not None
-        assert attn_metadata.is_causal_pooling is not None
 
-        if attn_metadata.is_causal_pooling:
+        if attn_metadata.causal:
             # use sparse_mode 3 in causal scenario
             return torch_npu.npu_fusion_attention(
                 query=query,
@@ -768,7 +753,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             key, value = self.reshape_and_cache(key, value, kv_cache,
                                                 attn_metadata)
         # pooling model branch
-        if isinstance(attn_metadata.is_causal_pooling, bool):
+        if attn_metadata.model_runner_type == "pooling":
            attn_output = self._forward_encoder_attention(
                query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
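The old `is_causal_pooling: Optional[bool]` tri-state conflated two questions: "is this a pooling model?" and "is attention causal?". A minimal sketch of the two-flag replacement (toy classes, not the real metadata):

```python
from dataclasses import dataclass


@dataclass
class Meta:
    # The two questions the old Optional[bool] answered at once:
    causal: bool = True          # how to attend
    model_runner_type: str = ""  # which runner is active


def takes_pooling_branch(meta: Meta) -> bool:
    # Mirrors the new forward() check: the pooling branch no longer
    # depends on a sentinel None/bool distinction.
    return meta.model_runner_type == "pooling"


assert takes_pooling_branch(Meta(causal=False, model_runner_type="pooling"))
assert takes_pooling_branch(Meta(causal=True, model_runner_type="pooling"))
assert not takes_pooling_branch(Meta())  # generate path untouched
```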
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 17a5f04e..352f1d33 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -23,6 +23,7 @@ from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         enable_cp,
                                          maybe_save_kv_layer_to_connector,
                                          split_decodes_and_prefills,
                                          trans_rope_weight, transdata,
@@ -57,8 +58,7 @@ class AscendMLABackend(AttentionBackend):
 
     @staticmethod
     def get_builder_cls():
-        prefill_config = get_current_vllm_config().parallel_config
-        if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
+        if enable_cp():
             from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
             return AscendMlaCPMetadataBuilder
         return AscendMLAMetadataBuilder
@@ -70,8 +70,7 @@ class AscendMLABackend(AttentionBackend):
 
     @staticmethod
     def get_impl_cls() -> Type["MLAAttentionImpl"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
+        if enable_cp():
             from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
             return AscendMlaCPImpl
         return AscendMLAImpl
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index ac19dc9d..22200256 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -1,9 +1,10 @@
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, List, Optional
 
 import torch
 import torch.nn.functional as F
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
@@ -26,6 +27,13 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
     return runtime_shape in get_ascend_config().pa_shape_list
 
 
+@lru_cache(maxsize=1)
+def enable_cp() -> bool:
+    parallel_config = get_current_vllm_config().parallel_config
+    return parallel_config.prefill_context_parallel_size > 1 \
+        or parallel_config.decode_context_parallel_size > 1
+
+
 @dataclass
 # class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata:
@@ -66,7 +74,7 @@ class AscendCommonAttentionMetadata:
     Per-batch attention metadata, shared across layers and backends.
     AttentionMetadataBuilder instances use it to construct per-layer metadata.
 
-    For many of the tensors we keep both GPU and CPU versions.
+    For many of the tensors we keep both NPU and CPU versions.
     """
 
     query_start_loc: torch.Tensor
@@ -109,8 +117,6 @@ class AscendCommonAttentionMetadata:
 
     attn_state: Any = None
 
-    is_only_prefill: bool = False
-
     graph_pad_size: int = -1
 
     # num_input_tokens refers to total number of tokens including
@@ -120,6 +126,8 @@ class AscendCommonAttentionMetadata:
     prefill_context_parallel_metadata: Optional[
         AscendPrefillContextParallelMetadata] = None
 
+    causal: bool = True
+
     # TODO: Remove it when vLLM no longer uses this function.
     def unpadded(self, num_actual_tokens: int,
                  num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
@@ -137,12 +145,12 @@ class AscendCommonAttentionMetadata:
             decode_token_per_req=self.decode_token_per_req,
             block_table_tensor=self.block_table_tensor[:num_actual_reqs],
             slot_mapping=self.slot_mapping[:num_actual_tokens],
+            causal=self.causal,
             actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
             positions=self.positions[:num_actual_tokens],
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            is_only_prefill=self.is_only_prefill,
             graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
             num_input_tokens=num_actual_tokens,
             prefill_context_parallel_metadata=self.
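`enable_cp()` now runs on every `get_impl_cls()` / `get_builder_cls()` lookup, so the `@lru_cache(maxsize=1)` matters: the parallel config is read once and every later call is a cache hit. One caveat worth flagging (an assumption about call ordering, not something this diff enforces): the result is frozen at the first call, so `enable_cp()` must not run before `get_current_vllm_config()` returns the final config. A minimal sketch of the pattern:

```python
from functools import lru_cache

calls = 0


@lru_cache(maxsize=1)
def cached_flag() -> bool:
    # Stands in for reading parallel_config; the body runs exactly once.
    global calls
    calls += 1
    return True


assert cached_flag() and cached_flag() and cached_flag()
assert calls == 1  # later calls never re-read the config
```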
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index c4990497..c3c52cb2 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -271,8 +271,8 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
          attn_output, softmax_lse) = param
         seq_lens = forward_context.attn_metadata[key].seq_lens_list
-        query_start_loc = forward_context.attn_metadata[
-            key].query_start_loc_list
+        actual_seq_lengths_q = forward_context.attn_metadata[
+            key].actual_seq_lengths_q
 
         torch.npu.graph_task_update_begin(update_stream, handle)
         torch_npu.npu_fused_infer_attention_score.out(
             query=query,
@@ -282,7 +282,7 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
             atten_mask=attn_mask,
             input_layout="TND",
             block_size=block_size,
-            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths=actual_seq_lengths_q,
             actual_seq_lengths_kv=seq_lens,
             num_key_value_heads=num_kv_heads,
             num_heads=num_heads,
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 0518aa4a..356efad6 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -384,12 +384,11 @@ class EagleProposer(Proposer):
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
-        attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
-            1:].tolist()
         attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
         attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
-        attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
+        attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[
+            1:].tolist()
         attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
 
         for now_speculative in range(
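The eagle_proposer change is behavior-preserving: with `max_query_len = 1` and `query_start_loc = arange(batch_size + 1)`, the cumulative tail equals the old hand-built list. A quick check with assumed values:

```python
batch_size = 4
query_start_loc = list(range(batch_size + 1))  # [0, 1, 2, 3, 4]

# New form: cumulative offsets without the leading zero...
actual_seq_lengths_q = query_start_loc[1:]

# ...matches the old hard-coded one-token-per-request list.
assert actual_seq_lengths_q == [1 + i for i in range(batch_size)]
```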
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 94ae0916..38cfcd0c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1042,7 +1042,6 @@ class NPUModelRunner(GPUModelRunner):
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
             max_query_len=max_num_scheduled_tokens,
             decode_token_per_req=self.decode_token_per_req,
             prefill_context_parallel_metadata=long_seq_metadata,
diff --git a/vllm_ascend/worker/v2/attn_utils.py b/vllm_ascend/worker/v2/attn_utils.py
index ad2caa55..eb536eaa 100644
--- a/vllm_ascend/worker/v2/attn_utils.py
+++ b/vllm_ascend/worker/v2/attn_utils.py
@@ -45,7 +45,6 @@ def build_attn_metadata(
     | None = None,
     spec_attn_mask: torch.Tensor | None = None,
     attn_state: Any | None = None,
-    is_only_prefill: bool = False,
     graph_pad_size: int = -1,
     num_input_tokens: int = 0,
     prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
@@ -78,7 +77,6 @@ def build_attn_metadata(
         attn_mask=attn_mask,
         spec_attn_mask=spec_attn_mask,
         attn_state=attn_state,
-        is_only_prefill=is_only_prefill,
         graph_pad_size=graph_pad_size,
         num_input_tokens=num_input_tokens,
         prefill_context_parallel_metadata=prefill_context_parallel_metadata,
diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py
index 6ffac6db..c41734a4 100644
--- a/vllm_ascend/xlite/xlite.py
+++ b/vllm_ascend/xlite/xlite.py
@@ -247,7 +247,9 @@ class XliteWrapper:
         if not with_prefill or self.full_mode:
             batch = attn_metadata.num_prefills + attn_metadata.num_decodes
             seq_lens = attn_metadata.seq_lens[:batch]
-            query_lens = attn_metadata.query_lens[:batch]
+            query_lens = attn_metadata.query_start_loc_cpu[
+                1:] - attn_metadata.query_start_loc_cpu[:-1]
+            query_lens = query_lens[:batch]
             cached_lens = seq_lens - query_lens
             xlite_attn_metadata = ModelAttnMeta()