### What this PR does / why we need it? The cache for MLA decode graph parameters was holding strong references to tensors, preventing them from being garbage collected and leading to increased memory usage. This change wraps the cached tensors in weak references, allowing them to be deallocated when no longer in use and reducing overall memory pressure. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? None. Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -443,7 +443,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
update_graph_params_workspaces(
|
||||
num_tokens, weak_ref_tensors(workspace))
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
@@ -459,7 +460,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.num_kv_heads,
|
||||
self.num_heads,
|
||||
self.scale,
|
||||
weak_ref_tensors(attn_metadata.block_tables),
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens,
|
||||
weak_ref_tensors(output),
|
||||
))
|
||||
|
||||
Reference in New Issue
Block a user