[BugFix] Fix max_num_tokens_across_dp calculation bugs in attention_v1_torchair (#1636)
### What this PR does / why we need it?
This PR fixes a bug in the `max_num_tokens_across_dp` calculation. In the earlier version, the value was computed as `graph_pad_size` plus the actual `max_num_tokens`, which yields a different `max_num_tokens_across_dp` on each DP rank. When subsequent padding depends on this value, the mismatch can produce incorrect padding.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed normally.

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
```diff
@@ -273,10 +273,10 @@ class AscendAttentionTorchairMetadataBuilder:
         if use_torchair_graph and self.runner.attn_state in [
                 AscendAttentionState.DecodeOnly,
         ]:
-            max_num_tokens_across_dp += graph_pad_size
             pad_value = 1
             padded_seq_lens = seq_lens.tolist() + [pad_value
                                                    ] * graph_pad_size
+            max_num_tokens_across_dp = len(padded_seq_lens)

             seq_lens = torch.from_numpy(
                 np.array(padded_seq_lens).astype(np.int32))
```
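To see why the fix makes the value consistent, here is a minimal standalone sketch (not the vllm-ascend code itself; `padded_max_tokens` and `target_len` are hypothetical names). Each rank pads its own `seq_lens` up to a common target length, so `len(padded_seq_lens)` is the same on every rank, whereas adding `graph_pad_size` to each rank's own actual token count would differ per rank:

```python
import numpy as np

def padded_max_tokens(seq_lens, target_len, pad_value=1):
    """Pad seq_lens with pad_value up to target_len; return the padded
    length (uniform across ranks) and the padded array."""
    graph_pad_size = target_len - len(seq_lens)
    padded_seq_lens = seq_lens + [pad_value] * graph_pad_size
    # len(padded_seq_lens) == target_len on every rank, regardless of
    # how many sequences the rank actually holds.
    return len(padded_seq_lens), np.array(padded_seq_lens).astype(np.int32)

# Two DP ranks with different actual batch sizes, same padded target of 8:
rank0_max, rank0_lens = padded_max_tokens([5, 3, 7], target_len=8)
rank1_max, rank1_lens = padded_max_tokens([2, 4], target_len=8)
assert rank0_max == rank1_max == 8  # consistent across DP ranks
```

Under the old scheme, rank 0 would have reported `3 + 5 = 8` tokens while rank 1 reported `2 + 6 = 8` only by coincidence of the shared target; taking the padded length directly removes any dependence on the rank-local count.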