[BugFix] Fix max_num_tokens_across_dp calculation bugs in attention_v1_torchair (#1636)
### What this PR does / why we need it?

This PR fixes a bug in the calculation of `max_num_tokens_across_dp`. In earlier versions, this value was computed by adding `graph_pad_size` to the actual maximum number of tokens. As a result, `max_num_tokens_across_dp` could differ from one DP rank to another, and any padding that relies on this value could be wrong.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed normally.

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
@@ -273,10 +273,10 @@ class AscendAttentionTorchairMetadataBuilder:
 if use_torchair_graph and self.runner.attn_state in [
         AscendAttentionState.DecodeOnly,
 ]:
-    max_num_tokens_across_dp += graph_pad_size
     pad_value = 1
     padded_seq_lens = seq_lens.tolist() + [pad_value
                                            ] * graph_pad_size
+    max_num_tokens_across_dp = len(padded_seq_lens)

     seq_lens = torch.from_numpy(
         np.array(padded_seq_lens).astype(np.int32))
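The effect of the change can be illustrated with a minimal sketch in plain Python, independent of the real builder. It assumes two hypothetical DP ranks in decode-only mode, that each rank pads its local batch up to a shared graph batch size (so `graph_pad_size` differs per rank), and that the pre-fix starting value was the largest actual token count over all ranks. The helper names `old_value` and `new_value` and the numbers used are illustrative only and do not come from the patch.

```python
def old_value(max_actual_across_dp, graph_pad_size):
    # Pre-fix behaviour as described in the commit message:
    # add the rank-local pad size to the actual maximum.
    return max_actual_across_dp + graph_pad_size


def new_value(seq_lens, graph_pad_size, pad_value=1):
    # Post-fix behaviour: pad the local sequence lengths first,
    # then take the length of the padded list.
    padded_seq_lens = list(seq_lens) + [pad_value] * graph_pad_size
    return len(padded_seq_lens)


# Assumption: both ranks pad to the same graph batch size.
graph_batch_size = 8

# Hypothetical per-rank sequence lengths (one decode token per sequence).
rank_seq_lens = [[5, 7, 3], [4, 9, 2, 6, 1]]

# Largest actual number of tokens on any rank.
max_actual = max(len(s) for s in rank_seq_lens)

# Rank-local pad sizes needed to reach the graph batch size.
pads = [graph_batch_size - len(s) for s in rank_seq_lens]

old = [old_value(max_actual, p) for p in pads]
new = [new_value(s, p) for s, p in zip(rank_seq_lens, pads)]

print("old per-rank values:", old)
print("new per-rank values:", new)
```

Under these assumptions the old formula gives a different value on each rank, while the length of the padded list is the same on both, which matches the problem the commit message describes.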