[v0.18.0] fix(attention): reuse weight address in graph + RL scenario (#7715)

### What this PR does / why we need it?

In the graph + RL scenario, the graph is captured only once, and the weight
addresses are expected to stay the same across iterations. However, calling
.contiguous() on a weight tensor may allocate new memory, so the captured
graph ends up pointing at a stale weight address. This PR changes the weight
update logic in AscendMLAImpl and AscendSFAImpl to use copy_() instead of
reassignment, ensuring the weight addresses remain consistent across
iterations.

See #7473 for details.
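
For illustration only (not code from this PR): a minimal PyTorch sketch of why reassignment breaks address reuse and why copy_() fixes it. The tensor names and shapes are made up, and data_ptr() stands in for the address a captured graph would record.

```python
import torch

# Hypothetical weight, roughly shaped like the kv_b_proj output: (L, N, V).
w = torch.randn(4, 8, 16)

# Reassignment: transpose() returns a non-contiguous view, and .contiguous()
# materializes it into freshly allocated storage, so every update produces a
# different address.
w_uv_a = w.transpose(0, 1).contiguous()
w_uv_b = w.transpose(0, 1).contiguous()
print(w_uv_a.data_ptr() == w_uv_b.data_ptr())  # False: two separate allocations

# copy_(): allocate once, then overwrite the contents in place, so the address
# recorded at graph-capture time stays valid across weight updates.
w_uv = w.transpose(0, 1).contiguous()
ptr_before = w_uv.data_ptr()
w_uv.copy_(w.transpose(0, 1).contiguous())  # simulated RL weight reload
print(ptr_before == w_uv.data_ptr())        # True: same storage, new values
```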

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: Debonex <719893090@qq.com>
This commit is contained in:
Debonex
2026-03-27 14:11:20 +08:00
committed by GitHub
parent 29308ac3a9
commit 6ce1dc162a
2 changed files with 22 additions and 8 deletions


@@ -884,10 +884,17 @@ class AscendMLAImpl(MLAAttentionImpl):
         W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        # NOTE: When we make a incontiguous weight contiguous, a new address will be allocated for the weight,
+        # in graph + RL scenario, we only capture the graph once, and the weight address is expected to be the same
+        # across iterations, so we need to copy the weight to the original address after making it contiguous.
+        if not hasattr(self, "W_UV"):
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1).contiguous()
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        else:
+            self.W_UV.copy_(W_UV.transpose(0, 1).contiguous())
+            self.W_UK_T.copy_(W_UK.permute(1, 2, 0).contiguous())
         # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
         # self.W_UV = maybe_trans_nz(self.W_UV)


@@ -483,10 +483,17 @@ class AscendSFAImpl(MLAAttentionImpl):
         W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        # NOTE: When we make a incontiguous weight contiguous, a new address will be allocated for the weight,
+        # in graph + RL scenario, we only capture the graph once, and the weight address is expected to be the same
+        # across iterations, so we need to copy the weight to the original address after making it contiguous.
+        if not hasattr(self, "W_UV"):
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1).contiguous()
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        else:
+            self.W_UV.copy_(W_UV.transpose(0, 1).contiguous())
+            self.W_UK_T.copy_(W_UK.permute(1, 2, 0).contiguous())
         # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
         # self.W_UV = maybe_trans_nz(self.W_UV)
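
The hasattr-guarded allocate-then-copy pattern used in both hunks can also be written as a small standalone helper. This is a hypothetical sketch, not part of the PR; update_buffer_in_place and _FakeImpl are invented names for illustration.

```python
import torch

def update_buffer_in_place(owner, name: str, new_value: torch.Tensor) -> None:
    """Allocate the buffer on first use; on later updates, copy into the
    existing storage so a captured graph keeps pointing at valid memory."""
    existing = getattr(owner, name, None)
    if existing is None:
        setattr(owner, name, new_value.contiguous())
    else:
        existing.copy_(new_value)

class _FakeImpl:  # stand-in for AscendMLAImpl / AscendSFAImpl
    pass

impl = _FakeImpl()
w = torch.randn(4, 8, 16)
update_buffer_in_place(impl, "W_UV", w.transpose(0, 1))        # first load: allocate
ptr = impl.W_UV.data_ptr()
update_buffer_in_place(impl, "W_UV", (w * 2).transpose(0, 1))  # weight reload: copy in place
assert impl.W_UV.data_ptr() == ptr  # address reused across updates
```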