### What this PR does / why we need it? The cache for MLA decode graph parameters was holding strong references to tensors, preventing them from being garbage collected and leading to increased memory usage. This change wraps the cached tensors in weak references, allowing them to be deallocated when no longer in use and reducing overall memory pressure. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? None. Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -25,14 +25,15 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_enable_nz)
|
||||
is_enable_nz, weak_ref_tensors)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -663,7 +664,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
|
||||
None), AscendW8A8LinearMethod):
|
||||
self.enable_mlapo = False
|
||||
logger.warning(
|
||||
logger.warning_once(
|
||||
"Currently mlapo only supports W8A8 quantization in MLA scenario."
|
||||
"Some layers in your model are not quantized with W8A8,"
|
||||
"thus mlapo is disabled for these layers.")
|
||||
@@ -1035,7 +1036,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
graph_params.workspaces[num_tokens] = workspace
|
||||
update_graph_params_workspaces(num_tokens,
|
||||
weak_ref_tensors(workspace))
|
||||
|
||||
attn_output = torch.empty_like(q_nope)
|
||||
softmax_lse = torch.empty(num_tokens,
|
||||
@@ -1043,11 +1045,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
device=q_nope.device)
|
||||
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
|
||||
input_layout, spec_attn_mask, sparse_mode, self.scale,
|
||||
decode_meta.block_table, block_size,
|
||||
decode_meta.seq_lens_list, actual_seq_lengths, workspace,
|
||||
attn_output, softmax_lse))
|
||||
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
|
||||
weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
|
||||
self.num_heads, self.num_kv_heads, input_layout,
|
||||
weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
|
||||
else None, sparse_mode, self.scale, decode_meta.block_table,
|
||||
block_size, decode_meta.seq_lens_list, actual_seq_lengths,
|
||||
weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
|
||||
Reference in New Issue
Block a user