[MISC] Cherry pick #1291 from v0.9.1-dev (#1825)

### What this PR does / why we need it?
Cherry pick #1291 from v0.9.1-dev, This pr implement the synchronization
of whether `dbo` is enabled across all dp ranks. specifically, it
performed allreduce op across multiple DP ranks, only when all the dp
rank is `enable_dbo`, it is enabled

Co-authored-by: shikang-hangzhou <459956190@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>

- vLLM version: v0.10.0
- vLLM main:
2836dd73f1

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-08-01 09:08:45 +08:00
committed by GitHub
parent 9e65da990e
commit 2284289880
6 changed files with 68 additions and 37 deletions

View File

@@ -150,6 +150,8 @@ class AscendMetadata:
# (num_tokens,)
slot_mapping: torch.Tensor = None
enable_dbo_across_dp: bool = False
class AscendAttentionMetadataBuilder:
@@ -160,7 +162,11 @@ class AscendAttentionMetadataBuilder:
scheduler_output: "SchedulerOutput") -> bool:
return False
def build(self, num_reqs, num_actual_tokens, max_query_len):
def build(self,
num_reqs,
num_actual_tokens,
max_query_len,
enable_dbo_across_dp: bool = False):
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)
@@ -187,15 +193,17 @@ class AscendAttentionMetadataBuilder:
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=enable_dbo_across_dp)
return attn_metadata