diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 06108660..1b5a9457 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -884,10 +884,17 @@ class AscendMLAImpl(MLAAttentionImpl):
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        # NOTE: When we make an incontiguous weight contiguous, a new address is allocated for the weight.
+        # In the graph + RL scenario we capture the graph only once, and the weight address is expected to stay
+        # the same across iterations, so we copy the weight back to the original address after making it contiguous.
+        if not hasattr(self, "W_UV"):
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1).contiguous()
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        else:
+            self.W_UV.copy_(W_UV.transpose(0, 1).contiguous())
+            self.W_UK_T.copy_(W_UK.permute(1, 2, 0).contiguous())
 
         # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
         # self.W_UV = maybe_trans_nz(self.W_UV)
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 7d787648..0aec273b 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -483,10 +483,17 @@ class AscendSFAImpl(MLAAttentionImpl):
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        # NOTE: When we make an incontiguous weight contiguous, a new address is allocated for the weight.
+        # In the graph + RL scenario we capture the graph only once, and the weight address is expected to stay
+        # the same across iterations, so we copy the weight back to the original address after making it contiguous.
+        if not hasattr(self, "W_UV"):
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1).contiguous()
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        else:
+            self.W_UV.copy_(W_UV.transpose(0, 1).contiguous())
+            self.W_UK_T.copy_(W_UK.permute(1, 2, 0).contiguous())
 
         # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
         # self.W_UV = maybe_trans_nz(self.W_UV)
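
For context, here is a minimal standalone sketch of the pattern both hunks introduce. The WeightHolder class, shapes, and names below are hypothetical, not vllm-ascend code: the first call allocates fresh buffers via .contiguous(), while later reloads use Tensor.copy_ so the buffers keep the device address that a previously captured graph recorded.

import torch


class WeightHolder:
    """Hypothetical holder class (not vllm-ascend code) used only to
    illustrate the address-preserving reload pattern from the diff above."""

    def process_weights(self, w_uv: torch.Tensor, w_uk: torch.Tensor) -> None:
        if not hasattr(self, "W_UV"):
            # First load: .contiguous() materializes the permuted views into
            # freshly allocated buffers; a captured graph records these addresses.
            self.W_UV = w_uv.transpose(0, 1).contiguous()      # (L, N, V) -> (N, L, V)
            self.W_UK_T = w_uk.permute(1, 2, 0).contiguous()   # (L, N, P) -> (N, P, L)
        else:
            # Later reloads (e.g. RL weight sync): write the new values into the
            # existing buffers so the captured graph still reads the same address.
            self.W_UV.copy_(w_uv.transpose(0, 1).contiguous())
            self.W_UK_T.copy_(w_uk.permute(1, 2, 0).contiguous())


holder = WeightHolder()
w_uv, w_uk = torch.randn(4, 2, 8), torch.randn(4, 2, 16)   # (L, N, V), (L, N, P)

holder.process_weights(w_uv, w_uk)
ptr = holder.W_UV.data_ptr()

# Reload with updated values: contents change, buffer address does not.
holder.process_weights(w_uv * 2, w_uk * 2)
assert holder.W_UV.data_ptr() == ptr
assert torch.equal(holder.W_UV, (w_uv * 2).transpose(0, 1))

If the else branch instead re-assigned self.W_UV = ..., the attribute would be rebound to a new allocation while the previously captured graph kept reading the stale buffer, which is exactly what the patch avoids.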