[Bugfix][DP] Add with_prefill_across_dp to AscendMetadata to fix dp (#1094)
### What this PR does / why we need it? Add `with_prefill_across_dp` to `AscendMetadata` to fix DP. This PR fixes the bug introduced by #1012, which added an argument `with_prefill_across_dp` when dp_size > 1. Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -132,6 +132,8 @@ class AscendMetadata:
|
|||||||
# For logging.
|
# For logging.
|
||||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||||
|
|
||||||
|
with_prefill_across_dp: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionMetadataBuilder:
|
class AscendAttentionMetadataBuilder:
|
||||||
|
|
||||||
@@ -142,8 +144,12 @@ class AscendAttentionMetadataBuilder:
|
|||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def build(self, num_reqs, num_actual_tokens, max_query_len,
|
def build(self,
|
||||||
common_prefix_len):
|
num_reqs,
|
||||||
|
num_actual_tokens,
|
||||||
|
max_query_len,
|
||||||
|
common_prefix_len,
|
||||||
|
with_prefill_across_dp: bool = False):
|
||||||
|
|
||||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||||
)
|
)
|
||||||
@@ -160,7 +166,8 @@ class AscendAttentionMetadataBuilder:
|
|||||||
query_start_loc = query_start_loc_cpu.to(self.runner.device,
|
query_start_loc = query_start_loc_cpu.to(self.runner.device,
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
|
||||||
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
|
attn_metadata = AscendMetadata(
|
||||||
|
num_actual_tokens=num_actual_tokens,
|
||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
query_lens=query_lens,
|
query_lens=query_lens,
|
||||||
@@ -168,7 +175,8 @@ class AscendAttentionMetadataBuilder:
|
|||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
attn_state=attn_state)
|
attn_state=attn_state,
|
||||||
|
with_prefill_across_dp=with_prefill_across_dp)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user