### What this PR does / why we need it?
Cherry-pick #1291 from v0.9.1-dev. This PR implements synchronization of
whether `dbo` is enabled across all DP ranks: it performs an allreduce
over the DP ranks, and `dbo` is enabled only when every DP rank has
`enable_dbo` set, as sketched below.
Co-authored-by: shikang-hangzhou <459956190@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
- vLLM version: v0.10.0
- vLLM main: 2836dd73f1
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
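The allreduce-based check itself is not part of the hunks shown below; a minimal sketch of the idea, assuming `dp_group` is an initialized `torch.distributed` process group spanning the DP ranks (the helper name and group handle are illustrative, not taken from this PR):

```python
# Minimal sketch of the cross-DP-rank agreement described above.
# Assumption: `dp_group` is an initialized torch.distributed process group
# spanning the DP ranks; the helper name is illustrative, not from this PR.
import torch
import torch.distributed as dist


def sync_dbo_across_dp(local_enable_dbo: bool, dp_group) -> bool:
    """Return True only if every DP rank has dbo enabled locally."""
    # Encode the local flag as a tensor so it can join a collective op.
    flag = torch.tensor([1 if local_enable_dbo else 0], dtype=torch.int32)
    # MIN-allreduce yields 1 only when all ranks contributed 1, i.e. the
    # result stays set only if dbo is enabled on every DP rank.
    dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=dp_group)
    return bool(flag.item())
```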
```diff
@@ -150,6 +150,8 @@ class AscendMetadata:
     # (num_tokens,)
     slot_mapping: torch.Tensor = None
 
+    enable_dbo_across_dp: bool = False
+
 
 class AscendAttentionMetadataBuilder:
@@ -160,7 +162,11 @@ class AscendAttentionMetadataBuilder:
                              scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self, num_reqs, num_actual_tokens, max_query_len):
+    def build(self,
+              num_reqs,
+              num_actual_tokens,
+              max_query_len,
+              enable_dbo_across_dp: bool = False):
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
@@ -187,15 +193,17 @@ class AscendAttentionMetadataBuilder:
         attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
                                               ACL_FORMAT_FRACTAL_NZ)
 
-        attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
-                                       block_tables=block_table,
-                                       query_start_loc=query_start_loc,
-                                       query_lens=query_lens,
-                                       seq_lens=seq_lens,
-                                       max_query_len=max_query_len,
-                                       slot_mapping=slot_mapping,
-                                       attn_mask=attn_mask,
-                                       attn_state=attn_state)
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            max_query_len=max_query_len,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state,
+            enable_dbo_across_dp=enable_dbo_across_dp)
         return attn_metadata
@@ -140,6 +140,8 @@ class AscendTorchairMetadata:
 
     decode: Optional[AscendDecodeMetadata] = None
 
+    enable_dbo_across_dp: bool = False
+
 
 class AscendAttentionTorchairMetadataBuilder:
@@ -220,7 +222,8 @@ class AscendAttentionTorchairMetadataBuilder:
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              graph_pad_size: int = -1):
+              graph_pad_size: int = -1,
+              enable_dbo_across_dp: bool = False):
 
         device = self.runner.device
@@ -298,7 +301,8 @@ class AscendAttentionTorchairMetadataBuilder:
             max_query_len=max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
-            attn_state=attn_state)
+            attn_state=attn_state,
+            enable_dbo_across_dp=enable_dbo_across_dp)
         return attn_metadata
@@ -137,6 +137,7 @@ class AscendMLAMetadata:
 
     decode: Optional[AscendMLADecodeMetadata] = None
     prefill: Optional[AscendMLAPrefillMetadata] = None
+    enable_dbo_across_dp: bool = False
 
     def __post_init__(self):
         pass
@@ -370,6 +371,7 @@ class AscendMLAMetadataBuilder:
         max_query_len: int,
         graph_pad_size: int = -1,
         query_start_loc: torch.Tensor = None,
+        enable_dbo_across_dp: bool = False,
     ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
@@ -536,6 +538,7 @@ class AscendMLAMetadataBuilder:
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
+            enable_dbo_across_dp=enable_dbo_across_dp,
         )
```
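For reference, a hypothetical call site that threads the synchronized flag into one of the builders above, reusing the `sync_dbo_across_dp` helper sketched earlier (`builder` and the surrounding variables are illustrative, not from this diff):

```python
# Hypothetical call site (names are illustrative, not from this diff):
# agree on the flag across DP ranks first, then hand it to the metadata
# builder so it lands on the attention metadata consumed downstream.
synced = sync_dbo_across_dp(local_enable_dbo=True, dp_group=dp_group)
attn_metadata = builder.build(num_reqs=num_reqs,
                              num_actual_tokens=num_actual_tokens,
                              max_query_len=max_query_len,
                              enable_dbo_across_dp=synced)
```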