From ace300a54908cfa068acbfff231bc61e2a8ef6dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=97=A0=E8=84=B8=E7=94=B7?= <244036962@qq.com> Date: Sat, 11 Oct 2025 10:20:10 +0800 Subject: [PATCH] [Bugfix] Fix the abnormal NPU memory usage in full graph mode. (#3331) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? In full graph mode, since the paged attention operators require updates, the parameters of these operators need to be retained. However, tensors such as the query, key cache, and value cache do not need to be persistently saved, and we can manually release this space via `weak_ref_tensor` to save memory. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: WithHades <244036962@qq.com> --- vllm_ascend/attention/attention_v1.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 98c8c57..bf881b1 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -39,6 +39,8 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) +from ..utils import weak_ref_tensors + class AscendAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -402,15 +404,15 @@ class AscendAttentionBackendImpl(AttentionImpl): graph_params.events[num_tokens].append(event) graph_params.attn_params[num_tokens].append(( - query, - self.key_cache, - self.value_cache, + weak_ref_tensors(query), + weak_ref_tensors(self.key_cache), + weak_ref_tensors(self.value_cache), self.num_kv_heads, self.num_heads, self.scale, - attn_metadata.block_tables, + weak_ref_tensors(attn_metadata.block_tables), 
attn_metadata.seq_lens, - output, + weak_ref_tensors(output), )) torch.npu.graph_task_group_begin(stream)