[v0.11.0][Fix] Prevent memory leak in MLA decode graph (#3743) (#3774)

### What this PR does / why we need it?
The cache of MLA decode graph parameters held strong references to the
tensors captured during graph capture, preventing them from being garbage
collected and leading to steadily increasing memory usage.

This change wraps the cached tensors in weak references, allowing them
to be deallocated when no longer in use and reducing overall memory
pressure.
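
As a rough illustration (not the repo's actual helper), the sketch below reproduces the failure mode with plain PyTorch and Python's `weakref` on CPU. Note that `weak_ref_tensors` in vllm-ascend instead returns a tensor that keeps the buffer pointer without owning the storage; the cache name and shapes here are hypothetical:

```python
import weakref

import torch

# Hypothetical stand-in for the per-shape graph-parameter cache:
# num_tokens -> workspace handle.
_workspaces: dict[int, weakref.ref] = {}


def capture(num_tokens: int) -> None:
    # Scratch buffer allocated while capturing the decode graph.
    workspace = torch.empty(num_tokens, 1024)
    # Before the fix (the leak): the cache owns the tensor forever.
    #   _workspaces[num_tokens] = workspace
    # After the fix: the cache holds only a weak handle, so the buffer
    # can be reclaimed once nothing in the graph still needs it.
    _workspaces[num_tokens] = weakref.ref(workspace)


capture(16)
# The last strong reference died when capture() returned, so the weak
# handle now resolves to None instead of pinning ~64 KiB per shape.
print(_workspaces[16]())  # -> None
```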

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
None.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: Yizhou
Date: 2025-10-27 16:00:20 +08:00 (committed by GitHub)
Commit: 43276fd822 (parent: 825fdfb197)
4 changed files with 26 additions and 16 deletions

```diff
@@ -443,7 +443,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     block_table=attn_metadata.block_tables,
                     context_lens=attn_metadata.seq_lens,
                     out=output)
-                update_graph_params_workspaces(num_tokens, workspace)
+                update_graph_params_workspaces(
+                    num_tokens, weak_ref_tensors(workspace))
 
             # Handle graph capturing mode
             stream = torch_npu.npu.current_stream()
@@ -459,7 +460,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                         self.num_kv_heads,
                         self.num_heads,
                         self.scale,
-                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.block_tables,
                         attn_metadata.seq_lens,
                         weak_ref_tensors(output),
                     ))
```

```diff
@@ -25,14 +25,15 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz)
+                               is_enable_nz, weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -663,7 +664,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
                         None), AscendW8A8LinearMethod):
             self.enable_mlapo = False
-            logger.warning(
+            logger.warning_once(
                 "Currently mlapo only supports W8A8 quantization in MLA scenario."
                 "Some layers in your model are not quantized with W8A8,"
                 "thus mlapo is disabled for these layers.")
@@ -1035,7 +1036,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                 if workspace is None:
                     workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                         q_nope, k_nope, k_nope, **common_kwargs)
-                    graph_params.workspaces[num_tokens] = workspace
+                    update_graph_params_workspaces(num_tokens,
+                                                   weak_ref_tensors(workspace))
 
                 attn_output = torch.empty_like(q_nope)
                 softmax_lse = torch.empty(num_tokens,
@@ -1043,11 +1045,13 @@ class AscendMLAImpl(MLAAttentionImpl):
                                           device=q_nope.device)
 
                 graph_params.attn_params[num_tokens].append(
-                    (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
-                     input_layout, spec_attn_mask, sparse_mode, self.scale,
-                     decode_meta.block_table, block_size,
-                     decode_meta.seq_lens_list, actual_seq_lengths, workspace,
-                     attn_output, softmax_lse))
+                    (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                     weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
+                     self.num_heads, self.num_kv_heads, input_layout,
+                     weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
+                     else None, sparse_mode, self.scale, decode_meta.block_table,
+                     block_size, decode_meta.seq_lens_list, actual_seq_lengths,
+                     weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
 
                 torch.npu.graph_task_group_begin(stream)
                 torch_npu.npu_fused_infer_attention_score.out(
```

```diff
@@ -212,7 +212,6 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             seq_lens,
             output,
         ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
         seq_lens = forward_context.attn_metadata[key].seq_lens
 
         torch_npu_check = version_check()
@@ -258,8 +257,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     ):
         (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
          spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, workspace, attn_output,
-         softmax_lse) = param
+         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
         seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
         if speculative_config and speculative_config.method == "deepseek_mtp":
             actual_seq_lengths = forward_context.attn_metadata[
@@ -295,7 +293,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                 block_size=block_size,
                 actual_seq_lengths_kv=seq_lens_list,
                 actual_seq_lengths=actual_seq_lengths,
-                workspace=workspace,
+                workspace=graph_params.workspaces.get(runtime_shape),
                 out=[attn_output, softmax_lse])
 
     torch.npu.graph_task_update_end(update_stream)
@@ -329,7 +327,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
     )
 
 
-def update_graph_params_workspaces(num_tokens: int, workspace: int):
+def update_graph_params_workspaces(num_tokens: int, workspace: Any):
     global _graph_params
     if _graph_params is not None:
         _graph_params.workspaces[num_tokens] = workspace
```
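
To make the capture/replay split above concrete: the workspace is no longer carried inside each cached parameter tuple; replay resolves it by shape from the shared registry. The sketch below is a trimmed, hypothetical version of that pattern, where `update_graph_params_workspaces` mirrors the function in the diff while `GraphParams` and `lookup_workspace` are invented scaffolding standing in for the real module state:

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class GraphParams:
    # num_tokens (capture shape) -> workspace handle; a stand-in for the
    # real module-level graph state.
    workspaces: dict[int, Any] = field(default_factory=dict)


_graph_params: Optional[GraphParams] = GraphParams()


def update_graph_params_workspaces(num_tokens: int, workspace: Any) -> None:
    # Capture time: record (a weak ref to) the workspace for this shape.
    if _graph_params is not None:
        _graph_params.workspaces[num_tokens] = workspace


def lookup_workspace(runtime_shape: int) -> Any:
    # Replay time: resolve the workspace by shape, instead of keeping a
    # strong reference inside every cached parameter tuple.
    assert _graph_params is not None
    return _graph_params.workspaces.get(runtime_shape)
```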

```diff
@@ -686,6 +686,13 @@ def weak_ref_tensors(
     """
     Convenience function to create weak references to tensors,
     for single tensor, list of tensors or tuple of tensors.
+
+    This function should be used in the following scenario:
+    When a tensor is created during graph capture, and it's held by a method
+    that's not part of the graph, we don't really need to store it, but we
+    **do need** its buffer pointer. If we don't handle this, it cannot
+    be garbage collected, leading to a memory leak. To avoid this,
+    we should create a weak reference to the tensor.
     """
     if isinstance(tensors, torch.Tensor):
         return weak_ref_tensor(tensors)
```
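
The hunk ends before the container cases, but given the docstring ("single tensor, list of tensors or tuple of tensors") the rest of the dispatcher presumably looks something like the sketch below; the exact fall-through error handling is an assumption:

```python
def weak_ref_tensors(tensors):
    # Sketch of the dispatch implied by the docstring above. Assumes the
    # module's single-tensor helper `weak_ref_tensor` and `import torch`.
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, list):
        return [weak_ref_tensor(t) for t in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(t) for t in tensors)
    # Assumption: anything else is rejected rather than passed through.
    raise ValueError("Invalid type for tensors")
```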