diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index f0bf060..799853f 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -443,7 +443,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                         block_table=attn_metadata.block_tables,
                         context_lens=attn_metadata.seq_lens,
                         out=output)
-                    update_graph_params_workspaces(num_tokens, workspace)
+                    update_graph_params_workspaces(
+                        num_tokens, weak_ref_tensors(workspace))
 
                 # Handle graph capturing mode
                 stream = torch_npu.npu.current_stream()
@@ -459,7 +460,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     self.num_kv_heads,
                     self.num_heads,
                     self.scale,
-                    weak_ref_tensors(attn_metadata.block_tables),
+                    attn_metadata.block_tables,
                     attn_metadata.seq_lens,
                     weak_ref_tensors(output),
                 ))
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index cb9d6d0..977c114 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -25,14 +25,15 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz)
+                               is_enable_nz, weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -663,7 +664,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
                         None), AscendW8A8LinearMethod):
             self.enable_mlapo = False
-            logger.warning(
+            logger.warning_once(
                 "Currently mlapo only supports W8A8 quantization in MLA scenario."
"Some layers in your model are not quantized with W8A8," "thus mlapo is disabled for these layers.") @@ -1035,7 +1036,8 @@ class AscendMLAImpl(MLAAttentionImpl): if workspace is None: workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( q_nope, k_nope, k_nope, **common_kwargs) - graph_params.workspaces[num_tokens] = workspace + update_graph_params_workspaces(num_tokens, + weak_ref_tensors(workspace)) attn_output = torch.empty_like(q_nope) softmax_lse = torch.empty(num_tokens, @@ -1043,11 +1045,13 @@ class AscendMLAImpl(MLAAttentionImpl): device=q_nope.device) graph_params.attn_params[num_tokens].append( - (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads, - input_layout, spec_attn_mask, sparse_mode, self.scale, - decode_meta.block_table, block_size, - decode_meta.seq_lens_list, actual_seq_lengths, workspace, - attn_output, softmax_lse)) + (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), + weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), + self.num_heads, self.num_kv_heads, input_layout, + weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None + else None, sparse_mode, self.scale, decode_meta.block_table, + block_size, decode_meta.seq_lens_list, actual_seq_lengths, + weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse))) torch.npu.graph_task_group_begin(stream) torch_npu.npu_fused_infer_attention_score.out( diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 2ba6b25..91d75b5 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -212,7 +212,6 @@ def update_attn_params(update_stream, forward_context, runtime_shape): seq_lens, output, ) = param - # block_table = forward_context.attn_metadata[key].block_tables seq_lens = forward_context.attn_metadata[key].seq_lens torch_npu_check = version_check() @@ -258,8 +257,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, ): (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, spec_attn_mask, sparse_mode, scale, block_table, block_size, - seq_lens_list, actual_seq_lengths, workspace, attn_output, - softmax_lse) = param + seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list if speculative_config and speculative_config.method == "deepseek_mtp": actual_seq_lengths = forward_context.attn_metadata[ @@ -295,7 +293,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, block_size=block_size, actual_seq_lengths_kv=seq_lens_list, actual_seq_lengths=actual_seq_lengths, - workspace=workspace, + workspace=graph_params.workspaces.get(runtime_shape), out=[attn_output, softmax_lse]) torch.npu.graph_task_update_end(update_stream) @@ -329,7 +327,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]): ) -def update_graph_params_workspaces(num_tokens: int, workspace: int): +def update_graph_params_workspaces(num_tokens: int, workspace: Any): global _graph_params if _graph_params is not None: _graph_params.workspaces[num_tokens] = workspace diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 6745c16..55c8ee7 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -686,6 +686,13 @@ def weak_ref_tensors( """ Convenience function to create weak references to tensors, for single tensor, list of tensors or tuple of tensors. 
+
+    Use this when a tensor created during graph capture is also held by
+    code outside the graph (e.g. the saved attention parameters). The
+    holder does not need to own the tensor, only its buffer pointer;
+    keeping a strong reference would prevent the tensor from being
+    garbage collected and would leak memory, so store a weak reference
+    instead.
     """
     if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
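
Reviewer note (not part of the diff): the sketch below illustrates the
ownership pattern this patch moves to. All names are stand-ins, not code
from the PR: the dict mimics GraphParams.workspaces, and weak_ref_tensor
is assumed to behave like the helper in vllm_ascend/utils.py, returning a
non-owning alias of the input's buffer (identity is used here only so the
sketch runs anywhere, without an NPU).

    from typing import Any, Optional

    import torch

    def weak_ref_tensor(t: torch.Tensor) -> torch.Tensor:
        # Stand-in: assume the real helper returns a tensor aliasing t's
        # buffer without owning it; identity keeps this sketch runnable.
        return t

    workspaces: dict[int, Any] = {}  # mimics GraphParams.workspaces

    def on_capture(num_tokens: int, workspace: torch.Tensor) -> None:
        # Before the patch the capture-time workspace tensor was stored
        # directly; that strong reference pinned it for the life of the
        # process and leaked device memory. Store a weak alias instead.
        workspaces[num_tokens] = weak_ref_tensor(workspace)

    def on_replay(runtime_shape: int) -> Optional[Any]:
        # At replay the workspace is fetched by runtime shape, mirroring
        # workspace=graph_params.workspaces.get(runtime_shape), instead
        # of riding along as a strong reference in the saved params.
        return workspaces.get(runtime_shape)

The same reasoning drives the weak_ref_tensors() wrapping of q_nope,
k_nope, attn_output, etc. in the saved attn_params: replay only needs
their buffer addresses, so the capture-time tensors stay collectable.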