From 6ce1dc162a9bf464cc6074b98a500d97cd1d624f Mon Sep 17 00:00:00 2001 From: Debonet <37174444+Debonex@users.noreply.github.com> Date: Fri, 27 Mar 2026 14:11:20 +0800 Subject: [PATCH] [v0.18.0] fix(attention): reuse weight address in graph + RL scenario (#7715) ### What this PR does / why we need it? In graph + RL scenario, we only capture the graph once, and the weight address is expected to be the same across iterations. However, when calling .contiguous() on weight tensors, a new memory address may be allocated, causing the graph to capture incorrect weight addresses. This PR modifies the weight update logic in AscendMLAImpl and AscendSFAImpl to use copy_() instead of reassignment, ensuring the weight addresses remain consistent across iterations. detailed in #7473 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Signed-off-by: Debonex <719893090@qq.com> --- vllm_ascend/attention/mla_v1.py | 15 +++++++++++---- vllm_ascend/attention/sfa_v1.py | 15 +++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 06108660..1b5a9457 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -884,10 +884,17 @@ class AscendMLAImpl(MLAAttentionImpl): W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1).contiguous() - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() + # NOTE: When we make a incontiguous weight contiguous, a new address will be allocated for the weight, + # in graph + RL scenario, we only capture the graph once, and the weight address is expected to be the same + # across iterations, so we need to copy the weight to the original address after making it contiguous. + if not hasattr(self, "W_UV"): + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() + else: + self.W_UV.copy_(W_UV.transpose(0, 1).contiguous()) + self.W_UK_T.copy_(W_UK.permute(1, 2, 0).contiguous()) # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz # self.W_UV = maybe_trans_nz(self.W_UV) diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 7d787648..0aec273b 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -483,10 +483,17 @@ class AscendSFAImpl(MLAAttentionImpl): W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1).contiguous() - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() + # NOTE: When we make a incontiguous weight contiguous, a new address will be allocated for the weight, + # in graph + RL scenario, we only capture the graph once, and the weight address is expected to be the same + # across iterations, so we need to copy the weight to the original address after making it contiguous. + if not hasattr(self, "W_UV"): + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() + else: + self.W_UV.copy_(W_UV.transpose(0, 1).contiguous()) + self.W_UK_T.copy_(W_UK.permute(1, 2, 0).contiguous()) # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz # self.W_UV = maybe_trans_nz(self.W_UV)