[Bugfix][DP] Add with_prefill_across_dp to AscendMetadata to fix dp (#1094)

### What this PR does / why we need it?
Add `with_prefill_across_dp` to `AscendMetadata` to fix DP.

This PR fixes the bug introduced by #1012, which added the argument
`with_prefill_across_dp` when dp_size > 1.

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-06-06 19:20:33 +08:00
committed by GitHub
parent 0b12c2acf7
commit c46632439a

View File

@@ -132,6 +132,8 @@ class AscendMetadata:
# For logging. # For logging.
num_input_tokens: int = 0 # Number of tokens including padding. num_input_tokens: int = 0 # Number of tokens including padding.
with_prefill_across_dp: bool = False
class AscendAttentionMetadataBuilder: class AscendAttentionMetadataBuilder:
@@ -142,8 +144,12 @@ class AscendAttentionMetadataBuilder:
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
return False return False
def build(self, num_reqs, num_actual_tokens, max_query_len, def build(self,
common_prefix_len): num_reqs,
num_actual_tokens,
max_query_len,
common_prefix_len,
with_prefill_across_dp: bool = False):
block_table = self.runner.input_batch.block_table[0].get_device_tensor( block_table = self.runner.input_batch.block_table[0].get_device_tensor(
) )
@@ -160,15 +166,17 @@ class AscendAttentionMetadataBuilder:
query_start_loc = query_start_loc_cpu.to(self.runner.device, query_start_loc = query_start_loc_cpu.to(self.runner.device,
non_blocking=True) non_blocking=True)
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens, attn_metadata = AscendMetadata(
block_tables=block_table, num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc, block_tables=block_table,
query_lens=query_lens, query_start_loc=query_start_loc,
seq_lens=seq_lens, query_lens=query_lens,
max_query_len=max_query_len, seq_lens=seq_lens,
slot_mapping=slot_mapping, max_query_len=max_query_len,
attn_mask=attn_mask, slot_mapping=slot_mapping,
attn_state=attn_state) attn_mask=attn_mask,
attn_state=attn_state,
with_prefill_across_dp=with_prefill_across_dp)
return attn_metadata return attn_metadata