[Bugfix][DP] Add with_prefill_across_dp to AscendMetadata to fix dp (#1094)

### What this PR does / why we need it?
Add `with_prefill_across_dp` to `AscendMetadata` to fix data-parallel (DP) inference.

This PR fixes the bug introduced by #1012, which added a
`with_prefill_across_dp` argument to the attention metadata builder when dp_size > 1.
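
For context, here is a minimal, self-contained sketch of the mismatch. The class and method names come from the diff below, but the bodies are simplified and the final call is an illustrative assumption: with dp_size > 1 the caller introduced in #1012 forwards `with_prefill_across_dp` to the builder, so both `build()` and `AscendMetadata` have to accept the flag, otherwise the call fails with an unexpected-keyword `TypeError`.

```python
# Simplified sketch, not the actual vllm-ascend classes; it only shows why
# the new keyword must be threaded through the builder and the metadata.
from dataclasses import dataclass


@dataclass
class AscendMetadata:
    num_actual_tokens: int = 0
    # New field added by this PR; defaults to False so non-DP paths
    # keep working unchanged.
    with_prefill_across_dp: bool = False


class AscendAttentionMetadataBuilder:

    def build(self,
              num_reqs,
              num_actual_tokens,
              max_query_len,
              common_prefix_len,
              with_prefill_across_dp: bool = False):
        # Forward the flag into the metadata so downstream attention code
        # can see whether any DP rank is still prefilling.
        return AscendMetadata(num_actual_tokens=num_actual_tokens,
                              with_prefill_across_dp=with_prefill_across_dp)


# With dp_size > 1 the runner (since #1012) passes the extra keyword.
# Before this fix, build() did not accept it and raised a TypeError.
builder = AscendAttentionMetadataBuilder()
meta = builder.build(num_reqs=1,
                     num_actual_tokens=4,
                     max_query_len=4,
                     common_prefix_len=0,
                     with_prefill_across_dp=True)
assert meta.with_prefill_across_dp
```

Defaulting the flag to `False` keeps the existing single-DP code path unchanged.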

Signed-off-by: MengqingCao <cmq0113@163.com>
Mengqing Cao, 2025-06-06 19:20:33 +08:00, committed by GitHub
parent 0b12c2acf7
commit c46632439a


```diff
@@ -132,6 +132,8 @@ class AscendMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    with_prefill_across_dp: bool = False
+
 
 class AscendAttentionMetadataBuilder:
@@ -142,8 +144,12 @@ class AscendAttentionMetadataBuilder:
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self, num_reqs, num_actual_tokens, max_query_len,
-              common_prefix_len):
+    def build(self,
+              num_reqs,
+              num_actual_tokens,
+              max_query_len,
+              common_prefix_len,
+              with_prefill_across_dp: bool = False):
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
@@ -160,15 +166,17 @@ class AscendAttentionMetadataBuilder:
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)
 
-        attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
-                                       block_tables=block_table,
-                                       query_start_loc=query_start_loc,
-                                       query_lens=query_lens,
-                                       seq_lens=seq_lens,
-                                       max_query_len=max_query_len,
-                                       slot_mapping=slot_mapping,
-                                       attn_mask=attn_mask,
-                                       attn_state=attn_state)
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            max_query_len=max_query_len,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state,
+            with_prefill_across_dp=with_prefill_across_dp)
         return attn_metadata
```