diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index c6524a78..77b5251d 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -154,8 +154,7 @@ class AscendConfig:
         # npu_fused_infer_attention_score in some cases. We allow to execute
         # _npu_paged_attention in this cases. This should be removed once
         # npu_fused_infer_attention_score performs better on all scenarios.
-        self.pa_shape_list = additional_config.get("pa_shape_list",
-                                                   [1, 2, 3, 4])
+        self.pa_shape_list = additional_config.get("pa_shape_list", [])

         kv_cfg = vllm_config.kv_transfer_config
         if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched",
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 003da21f..9121ae19 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -367,6 +367,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         kv_sharing_target_layer_name: Optional[str],
         **kwargs,
     ) -> None:
+        self.vllm_config = get_current_vllm_config()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -723,7 +724,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
     ):
         num_tokens = query.shape[0]
         if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
-                and using_paged_attention(num_tokens)
+                and using_paged_attention(num_tokens, self.vllm_config)
                 and self.sliding_window is None):
             output = self.forward_paged_attention(query, attn_metadata, output)
         else:
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 08a17fbc..ac19dc9d 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -1,10 +1,9 @@
 from dataclasses import dataclass
-from functools import lru_cache
 from typing import Any, List, Optional

 import torch
 import torch.nn.functional as F
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
@@ -14,9 +13,7 @@ from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
                                get_ascend_device_type)


-@lru_cache
-def using_paged_attention(runtime_shape: int) -> bool:
-    vllm_config = get_current_vllm_config()
+def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
     if vllm_config.speculative_config is not None:
         return False
     if get_ascend_device_type() == AscendDeviceType.A5:
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index ab1c6ae2..c4990497 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -296,8 +296,9 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
         event.record(update_stream)


-def update_attn_params(update_stream, forward_context, runtime_shape):
-    if using_paged_attention(runtime_shape):
+def update_attn_params(update_stream, forward_context, runtime_shape,
+                       vllm_config):
+    if using_paged_attention(runtime_shape, vllm_config):
         _update_attn_pa_params(update_stream, forward_context, runtime_shape)
     else:
         _update_attn_fia_params(update_stream, forward_context, runtime_shape)
diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
index d0f1aa53..2f34c408 100644
--- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
+++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
@@ -23,8 +23,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass,
                                              PatternPrettyPrinter)
 from vllm.attention.layer import Attention
 from vllm.compilation.vllm_inductor_pass import VllmInductorPass
-from vllm.config import (VllmConfig, get_current_vllm_config,
-                         get_layers_from_vllm_config)
+from vllm.config import VllmConfig, get_layers_from_vllm_config


 class QKNormRopeFusionPattern:
@@ -42,7 +41,6 @@ class QKNormRopeFusionPattern:
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.eps = eps
-        vllm_config = get_current_vllm_config()
         self.device = vllm_config.device_config.device if vllm_config.device_config else None

     def get_inputs(self):
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 9ab91214..ba69cff1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1165,7 +1165,8 @@ class NPUModelRunner(GPUModelRunner):
                                           maybe_padded_num_tokens)
             else:
                 update_attn_params(self.update_stream, forward_context,
-                                   maybe_padded_num_tokens)
+                                   maybe_padded_num_tokens,
+                                   self.vllm_config)

         if get_forward_context().sp_enabled and not isinstance(
                 hidden_states, IntermediateTensors):
@@ -1957,7 +1958,7 @@ class NPUModelRunner(GPUModelRunner):
                                               positions.shape[0])
             else:
                 update_attn_params(self.update_stream, forward_context,
-                                   num_tokens)
+                                   num_tokens, self.vllm_config)

         if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
             hidden_states, _ = hidden_states
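
Note on the behavior change: with pa_shape_list now defaulting to an empty list, the _npu_paged_attention path is effectively off unless a deployment opts in through additional_config. A minimal sketch of the opt-in, assuming the standard vLLM LLM entrypoint forwards additional_config to the Ascend plugin (the model name below is illustrative, not part of this change):

    from vllm import LLM

    # Hypothetical opt-in: restore the previous built-in default, i.e. use
    # _npu_paged_attention for decode-only batches of 1-4 tokens.
    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # illustrative model choice
        additional_config={"pa_shape_list": [1, 2, 3, 4]},
    )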
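Beyond threading the config explicitly, dropping @lru_cache from using_paged_attention likely also avoids a staleness hazard: a cache keyed only on runtime_shape pins whatever get_current_vllm_config() returned on the first call. A self-contained toy illustrating that pitfall (no vLLM imports; all names here are hypothetical):

    from functools import lru_cache

    _CONFIG = {"speculative": True}  # stand-in for mutable global config

    @lru_cache
    def cached_decision(shape: int) -> bool:
        # The cache key is only `shape`, but the result depends on _CONFIG.
        return not _CONFIG["speculative"]

    print(cached_decision(4))       # False
    _CONFIG["speculative"] = False
    print(cached_decision(4))       # still False: the stale cached entry wins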