From 8ab8111fdea3f47391c4245b9db4be624fa6452a Mon Sep 17 00:00:00 2001
From: Yizhou <136800916+yiz-liu@users.noreply.github.com>
Date: Sat, 25 Oct 2025 20:37:33 +0800
Subject: [PATCH] [Fix] Prevent memory leak in MLA decode graph (#3743)

### What this PR does / why we need it?
The cache for MLA decode graph parameters was holding strong references
to tensors, preventing them from being garbage collected and leading to
increased memory usage. This change wraps the cached tensors in weak
references, allowing them to be deallocated when no longer in use and
reducing overall memory pressure.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
None.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/c9461e05a4ed3557cfbf4b15ded1e26761cc39ca

---------

Signed-off-by: Yizhou Liu
---
 vllm_ascend/attention/attention_v1.py |  5 +++--
 vllm_ascend/attention/mla_v1.py       | 28 +++++++++++++++------------
 vllm_ascend/compilation/acl_graph.py  |  8 +++-----
 vllm_ascend/utils.py                  |  7 +++++++
 4 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 4b1f279f..ef36c04c 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -562,7 +562,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                         block_table=attn_metadata.block_tables,
                         context_lens=attn_metadata.seq_lens,
                         out=output)
-                    update_graph_params_workspaces(num_tokens, workspace)
+                    update_graph_params_workspaces(
+                        num_tokens, weak_ref_tensors(workspace))
 
                 # Handle graph capturing mode
                 stream = torch_npu.npu.current_stream()
@@ -578,7 +579,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     self.num_kv_heads,
                     self.num_heads,
                     self.scale,
-                    weak_ref_tensors(attn_metadata.block_tables),
+                    attn_metadata.block_tables,
                     attn_metadata.seq_lens,
                     weak_ref_tensors(output),
                 ))
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index c4503f59..ebba38c1 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -12,8 +12,6 @@ from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
-
-# isort: off
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
                               get_decode_context_model_parallel_world_size,
@@ -35,19 +33,22 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, prefill_context_parallel_enable)
+                               is_enable_nz, prefill_context_parallel_enable,
+                               weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
+# isort: off
 if prefill_context_parallel_enable():
     from vllm.distributed import (get_pcp_group,
                                   get_prefill_context_model_parallel_rank,
                                   get_prefill_context_model_parallel_world_size
                                   )
-# isort:on
+# isort: on
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -743,7 +744,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                     getattr(self.fused_qkv_a_proj.quant_method,
                             'quant_method', None), AscendW8A8LinearMethod):
                 self.enable_mlapo = False
-                logger.warning(
+                logger.warning_once(
                     "Currently mlapo only supports W8A8 quantization in MLA scenario."
                     "Some layers in your model are not quantized with W8A8,"
                     "thus mlapo is disabled for these layers.")
@@ -1115,7 +1116,8 @@ class AscendMLAImpl(MLAAttentionImpl):
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     q_nope, k_nope, k_nope, **common_kwargs)
-                graph_params.workspaces[num_tokens] = workspace
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
 
             attn_output = torch.empty_like(q_nope)
             softmax_lse = torch.empty(num_tokens,
@@ -1123,11 +1125,13 @@ class AscendMLAImpl(MLAAttentionImpl):
                                       device=q_nope.device)
 
             graph_params.attn_params[num_tokens].append(
-                (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
-                 input_layout, spec_attn_mask, sparse_mode, self.scale,
-                 decode_meta.block_table, block_size,
-                 decode_meta.seq_lens_list, actual_seq_lengths, workspace,
-                 attn_output, softmax_lse))
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
+                 self.num_heads, self.num_kv_heads, input_layout,
+                 weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
+                 else None, sparse_mode, self.scale, decode_meta.block_table,
+                 block_size, decode_meta.seq_lens_list, actual_seq_lengths,
+                 weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
 
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 2ba6b253..91d75b52 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -212,7 +212,6 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             seq_lens,
             output,
         ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
         seq_lens = forward_context.attn_metadata[key].seq_lens
 
         torch_npu_check = version_check()
@@ -258,8 +257,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
         ):
             (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
              spec_attn_mask, sparse_mode, scale, block_table, block_size,
-             seq_lens_list, actual_seq_lengths, workspace, attn_output,
-             softmax_lse) = param
+             seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
             seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
             if speculative_config and speculative_config.method == "deepseek_mtp":
                 actual_seq_lengths = forward_context.attn_metadata[
@@ -295,7 +293,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                 block_size=block_size,
                 actual_seq_lengths_kv=seq_lens_list,
                 actual_seq_lengths=actual_seq_lengths,
-                workspace=workspace,
+                workspace=graph_params.workspaces.get(runtime_shape),
                 out=[attn_output, softmax_lse])
 
             torch.npu.graph_task_update_end(update_stream)
@@ -329,7 +327,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
     )
 
 
-def update_graph_params_workspaces(num_tokens: int, workspace: int):
+def update_graph_params_workspaces(num_tokens: int, workspace: Any):
     global _graph_params
     if _graph_params is not None:
         _graph_params.workspaces[num_tokens] = workspace
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 36dedc5c..a3b908a0 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -697,6 +697,13 @@ def weak_ref_tensors(
     """
     Convenience function to create weak references to tensors, for single
     tensor, list of tensors or tuple of tensors.
+
+    This function should be used in the following scenario:
+    When a tensor is created during graph capture, and it's held by a method
+    that's not part of the graph, we don't really need to store it, but we
+    **do need** its buffer pointer. If we don't handle this, it cannot
+    be garbage collected, leading to a memory leak. To avoid this,
+    we should create a weak reference to the tensor.
     """
     if isinstance(tensors, torch.Tensor):
         return weak_ref_tensor(tensors)
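
For reference, the idea behind the fix can be illustrated outside the NPU stack. The sketch below is not part of the patch: the cache names are hypothetical, and it substitutes Python's standard `weakref` module for the `weak_ref_tensors` helper in `vllm_ascend/utils.py` (which, per the docstring above, instead returns a tensor that aliases the original buffer without owning it). It only demonstrates why a per-shape parameter cache must not hold strong references to captured tensors.

```python
# Minimal sketch (hypothetical names, CPU-only). The real code caches NPU
# workspaces and attention parameters per captured graph shape; here we only
# show that a weak-reference cache does not extend the tensors' lifetimes.
import weakref
from typing import Dict, Optional

import torch

# Stand-in for graph_params.workspaces: shape -> weak reference to a tensor.
_workspaces: Dict[int, "weakref.ReferenceType[torch.Tensor]"] = {}


def cache_workspace(num_tokens: int, workspace: torch.Tensor) -> None:
    # Store only a weak reference, so the cache alone cannot keep the
    # workspace buffer alive after its producer releases it.
    _workspaces[num_tokens] = weakref.ref(workspace)


def lookup_workspace(num_tokens: int) -> Optional[torch.Tensor]:
    # Resolve the weak reference; returns None once the tensor is gone.
    ref = _workspaces.get(num_tokens)
    return ref() if ref is not None else None


if __name__ == "__main__":
    ws = torch.empty(1 << 20)            # pretend this is a captured workspace
    cache_workspace(8, ws)
    assert lookup_workspace(8) is ws     # usable while the producer holds it
    del ws                               # producer drops its strong reference
    assert lookup_workspace(8) is None   # the cache does not pin the tensor
```

In the patch itself, the same effect is achieved by storing `weak_ref_tensors(...)` results in `graph_params.workspaces` and `graph_params.attn_params`, and by having `update_mla_attn_params` re-read the workspace via `graph_params.workspaces.get(runtime_shape)` at replay time instead of keeping it in the cached parameter tuple.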