[v0.18.0] fix(attention): reuse weight address in graph + RL scenario (#7715)
### What this PR does / why we need it?
In the graph + RL scenario, the graph is captured only once, so weight addresses are expected to stay the same across iterations. However, calling `.contiguous()` on a weight tensor may allocate new memory, so the captured graph ends up referencing a stale weight address. This PR changes the weight update logic in `AscendMLAImpl` and `AscendSFAImpl` to use `copy_()` instead of reassignment, keeping the weight addresses consistent across iterations. Detailed in #7473.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: Debonex <719893090@qq.com>
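The fix rests on the distinction between rebinding a Python attribute and writing into an existing buffer: `self.W_UV = W_UV.transpose(0, 1).contiguous()` binds the name to a freshly allocated tensor, while `self.W_UV.copy_(...)` writes into the storage the captured graph already points at. A minimal stdlib-only sketch of that distinction, using `id()` as a stand-in for a tensor's data address (no torch dependency assumed here):

```python
# Rebinding vs. in-place update: id() stands in for a tensor's data address.
weight = [1.0, 2.0, 3.0]
captured_address = id(weight)      # what a captured graph would "remember"

# Rebinding (analogous to `self.W_UV = W_UV.contiguous()`):
weight = [4.0, 5.0, 6.0]           # name now points at a NEW object
assert id(weight) != captured_address

# In-place update (analogous to `self.W_UV.copy_(...)`):
weight = [1.0, 2.0, 3.0]
captured_address = id(weight)
weight[:] = [4.0, 5.0, 6.0]        # same object, new contents
assert id(weight) == captured_address
assert weight == [4.0, 5.0, 6.0]
```

The captured graph holds the raw address, not the Python name, which is why only the in-place form keeps replay correct.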
```diff
@@ -884,10 +884,17 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        # NOTE: When we make a incontiguous weight contiguous, a new address will be allocated for the weight,
+        # in graph + RL scenario, we only capture the graph once, and the weight address is expected to be the same
+        # across iterations, so we need to copy the weight to the original address after making it contiguous.
+        if not hasattr(self, "W_UV"):
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1).contiguous()
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        else:
+            self.W_UV.copy_(W_UV.transpose(0, 1).contiguous())
+            self.W_UK_T.copy_(W_UK.permute(1, 2, 0).contiguous())
 
         # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
         # self.W_UV = maybe_trans_nz(self.W_UV)
 
```
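The guarded-init pattern in the hunk can be distilled as follows. This is a sketch only: the `FakeTensor` class is a hypothetical stand-in (the real code operates on torch tensors), with an `addr` field simulating a fresh allocation address per construction:

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor: stable 'address', supports copy_()."""
    _next_addr = 0

    def __init__(self, data):
        self.data = list(data)
        FakeTensor._next_addr += 1
        self.addr = FakeTensor._next_addr  # new allocation -> new address

    def copy_(self, other):
        self.data = list(other.data)       # in-place write: address unchanged
        return self


class Impl:
    def update_weight(self, w):
        # First call: allocate and bind. Later calls: copy into the existing
        # buffer so a captured graph keeps seeing the original address.
        if not hasattr(self, "W_UV"):
            self.W_UV = FakeTensor(w)
        else:
            self.W_UV.copy_(FakeTensor(w))


impl = Impl()
impl.update_weight([1, 2])
addr_after_first = impl.W_UV.addr
impl.update_weight([3, 4])                 # RL weight refresh
assert impl.W_UV.addr == addr_after_first  # address stable across updates
assert impl.W_UV.data == [3, 4]            # contents refreshed
```

Reassigning `self.W_UV` on every call would give a new `addr` each iteration, which is exactly the bug the PR fixes.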
```diff
@@ -483,10 +483,17 @@ class AscendSFAImpl(MLAAttentionImpl):
 
         W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        # NOTE: When we make a incontiguous weight contiguous, a new address will be allocated for the weight,
+        # in graph + RL scenario, we only capture the graph once, and the weight address is expected to be the same
+        # across iterations, so we need to copy the weight to the original address after making it contiguous.
+        if not hasattr(self, "W_UV"):
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1).contiguous()
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        else:
+            self.W_UV.copy_(W_UV.transpose(0, 1).contiguous())
+            self.W_UK_T.copy_(W_UK.permute(1, 2, 0).contiguous())
 
         # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
         # self.W_UV = maybe_trans_nz(self.W_UV)
 
```