[DP] Tiny fix of dp and update example (#1273)
### What this PR does / why we need it?
Add `max_num_tokens_across_dp` to `AscendMetadata` to fix DP.

This PR fixes the bug introduced by https://github.com/vllm-project/vllm-ascend/pull/1229, which added an argument `max_num_tokens_across_dp` that is passed when dp_size > 1.

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -119,6 +119,10 @@ class AscendMetadata:
|
||||
query_start_loc: torch.Tensor
|
||||
query_lens: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# max value of number of tokens across dp group
|
||||
max_num_tokens_across_dp: int = 0
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
@@ -155,6 +159,7 @@ class AscendAttentionMetadataBuilder:
|
||||
num_actual_tokens,
|
||||
max_query_len,
|
||||
common_prefix_len,
|
||||
max_num_tokens_across_dp: int = 0,
|
||||
with_prefill_across_dp: bool = False):
|
||||
|
||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||
@@ -192,6 +197,7 @@ class AscendAttentionMetadataBuilder:
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
||||
with_prefill_across_dp=with_prefill_across_dp)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
Reference in New Issue
Block a user