[Fix] Prevent memory leak in MLA decode graph (#3743)

### What this PR does / why we need it?
The cache for MLA decode graph parameters was holding strong references
to tensors, preventing them from being garbage collected and leading to
increased memory usage.

This change wraps the cached tensors in weak references, allowing them
to be deallocated when no longer in use and reducing overall memory
pressure.
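
For illustration, the pattern looks roughly like the sketch below: a cache that stores `weakref.ref` handles instead of the tensors themselves. This is a simplified, hypothetical stand-in (`_workspace_cache`, `cache_workspace`, and `get_workspace` are not names from this PR); the actual change routes through vllm_ascend's `weak_ref_tensors` and `update_graph_params_workspaces` helpers, visible in the diff below.

```python
import weakref
from typing import Dict, Optional

import torch

# Hypothetical cache keyed by token count (illustrative only; the real
# store lives in vllm_ascend's ACL graph params, not a module-level dict).
_workspace_cache: Dict[int, "weakref.ref[torch.Tensor]"] = {}


def cache_workspace(num_tokens: int, workspace: torch.Tensor) -> None:
    # Store a weak reference: the cache entry no longer keeps the
    # workspace tensor alive once all strong references are dropped.
    _workspace_cache[num_tokens] = weakref.ref(workspace)


def get_workspace(num_tokens: int) -> Optional[torch.Tensor]:
    ref = _workspace_cache.get(num_tokens)
    # ref() returns None after the tensor has been garbage collected,
    # so callers must be prepared to re-allocate.
    return ref() if ref is not None else None
```

With weak references the cache acts as a lookup rather than an owner: once the last strong reference to a workspace tensor goes away, the entry reads back as `None` and the memory can be reclaimed.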

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
None.

- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: Yizhou
Date: 2025-10-25 20:37:33 +08:00
Committed by: GitHub
Commit: 8ab8111fde (parent: afc58184ec)
4 changed files with 29 additions and 19 deletions


```diff
@@ -12,8 +12,6 @@ from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
-# isort: off
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
                               get_decode_context_model_parallel_world_size,
@@ -35,19 +33,22 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, prefill_context_parallel_enable)
+                               is_enable_nz, prefill_context_parallel_enable,
+                               weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch
+# isort: off
 if prefill_context_parallel_enable():
     from vllm.distributed import (get_pcp_group,
                                   get_prefill_context_model_parallel_rank,
                                   get_prefill_context_model_parallel_world_size
                                   )
-# isort:on
+# isort: on
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -743,7 +744,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                     getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
                             None), AscendW8A8LinearMethod):
                 self.enable_mlapo = False
-                logger.warning(
+                logger.warning_once(
                     "Currently mlapo only supports W8A8 quantization in MLA scenario."
                     "Some layers in your model are not quantized with W8A8,"
                     "thus mlapo is disabled for these layers.")
@@ -1115,7 +1116,8 @@ class AscendMLAImpl(MLAAttentionImpl):
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     q_nope, k_nope, k_nope, **common_kwargs)
-                graph_params.workspaces[num_tokens] = workspace
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
 
             attn_output = torch.empty_like(q_nope)
             softmax_lse = torch.empty(num_tokens,
@@ -1123,11 +1125,13 @@ class AscendMLAImpl(MLAAttentionImpl):
                                       device=q_nope.device)
 
             graph_params.attn_params[num_tokens].append(
-                (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
-                 input_layout, spec_attn_mask, sparse_mode, self.scale,
-                 decode_meta.block_table, block_size,
-                 decode_meta.seq_lens_list, actual_seq_lengths, workspace,
-                 attn_output, softmax_lse))
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
+                 self.num_heads, self.num_kv_heads, input_layout,
+                 weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
+                 else None, sparse_mode, self.scale, decode_meta.block_table,
+                 block_size, decode_meta.seq_lens_list, actual_seq_lengths,
+                 weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
 
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
```