[MISC] Cherry pick #1291 from v0.9.1-dev (#1825)

### What this PR does / why we need it?
Cherry-pick #1291 from v0.9.1-dev. This PR implements the synchronization
of whether `dbo` is enabled across all DP ranks. Specifically, it
performs an allreduce op across multiple DP ranks; `dbo` is enabled
only when every DP rank has `enable_dbo` set.

Co-authored-by: shikang-hangzhou <459956190@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>

- vLLM version: v0.10.0
- vLLM main:
2836dd73f1

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-08-01 09:08:45 +08:00
committed by GitHub
parent 9e65da990e
commit 2284289880
6 changed files with 68 additions and 37 deletions

View File

@@ -150,6 +150,8 @@ class AscendMetadata:
# (num_tokens,)
slot_mapping: torch.Tensor = None
enable_dbo_across_dp: bool = False
class AscendAttentionMetadataBuilder:
@@ -160,7 +162,11 @@ class AscendAttentionMetadataBuilder:
scheduler_output: "SchedulerOutput") -> bool:
return False
def build(self, num_reqs, num_actual_tokens, max_query_len):
def build(self,
num_reqs,
num_actual_tokens,
max_query_len,
enable_dbo_across_dp: bool = False):
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)
@@ -187,15 +193,17 @@ class AscendAttentionMetadataBuilder:
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=enable_dbo_across_dp)
return attn_metadata

View File

@@ -140,6 +140,8 @@ class AscendTorchairMetadata:
decode: Optional[AscendDecodeMetadata] = None
enable_dbo_across_dp: bool = False
class AscendAttentionTorchairMetadataBuilder:
@@ -220,7 +222,8 @@ class AscendAttentionTorchairMetadataBuilder:
num_reqs,
num_actual_tokens,
max_query_len,
graph_pad_size: int = -1):
graph_pad_size: int = -1,
enable_dbo_across_dp: bool = False):
device = self.runner.device
@@ -298,7 +301,8 @@ class AscendAttentionTorchairMetadataBuilder:
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state)
attn_state=attn_state,
enable_dbo_across_dp=enable_dbo_across_dp)
return attn_metadata

View File

@@ -137,6 +137,7 @@ class AscendMLAMetadata:
decode: Optional[AscendMLADecodeMetadata] = None
prefill: Optional[AscendMLAPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
@@ -370,6 +371,7 @@ class AscendMLAMetadataBuilder:
max_query_len: int,
graph_pad_size: int = -1,
query_start_loc: torch.Tensor = None,
enable_dbo_across_dp: bool = False,
) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs
@@ -536,6 +538,7 @@ class AscendMLAMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=enable_dbo_across_dp,
)

View File

@@ -75,7 +75,6 @@ from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
MultiStreamStepMetadata,
make_multistream_metadata_ds)
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import dispose_tensor
@@ -872,24 +871,9 @@ class CustomDeepseekDBOModel(nn.Module):
def can_run_ms(self):
attn_metadata = get_forward_context().attn_metadata
# support mla attention and V1 engine at present
if not self.use_mla:
return False
# enable prefill overlap
if attn_metadata is None or attn_metadata.num_prefills == 0:
return False
else:
[token_index, seq_index
] = compute_split_seq_index(attn_metadata.query_lens,
attn_metadata.attn_state,
attn_metadata.num_decode_tokens)
if token_index == 0 or seq_index == 0 or seq_index == len(
attn_metadata.query_lens):
return False
# check whether the total tokens exceed the threshold
if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
return False
return True
return not (attn_metadata is None or attn_metadata.num_prefills == 0
or not attn_metadata.enable_dbo_across_dp)
def _forward_ms_layers(
self,

View File

@@ -96,10 +96,12 @@ def model_input_split_v1_mla_attn(
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
query_start_loc_post = deepcopy(
attn_metadata.query_start_loc[seq_index:]
) - attn_metadata.query_start_loc[seq_index]
query_start_loc_pre = query_start_loc_post = None
if attn_metadata.query_start_loc is not None:
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
query_start_loc_post = deepcopy(
attn_metadata.query_start_loc[seq_index:]
) - attn_metadata.query_start_loc[seq_index]
[block_table_pre,
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
seq_index)
@@ -223,6 +225,7 @@ def model_input_split_v1_mla_attn(
attn_mask=attn_mask_pre,
prefill=prefill_pre,
decode=decode_pre,
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
)
attention_metadata_post = _metadata_cls(
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
@@ -239,5 +242,6 @@ def model_input_split_v1_mla_attn(
attn_state=attn_state_post,
prefill=prefill_post,
decode=decode_post,
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
)
return [attention_metadata_pre, attention_metadata_post]

View File

@@ -79,6 +79,7 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
@@ -606,6 +607,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
forward_metadata[-1])
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
attn_state: AscendAttentionState,
num_tokens: int) -> bool:
# do the checks for dp + dbo
if attn_state in [
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
]:
return False
# considering the case that one dp rank may enable dbo while others may not
if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
return False
# TODO: remove it if token-level microbatch is enabled
[token_index,
seq_index] = compute_split_seq_index(query_lens, attn_state,
num_tokens)
if token_index == 0 or seq_index == 0 or seq_index == len(
query_lens) or num_tokens < 256:
return False
return True
def get_eagle_atten_dict(
self,
scheduler_output: "SchedulerOutput",
@@ -1080,6 +1102,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
attn_state,
total_num_scheduled_tokens)
maybe_padded_num_tokens = total_num_scheduled_tokens
if self.torchair_graph_enabled and not with_prefill:
@@ -1087,7 +1112,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
total_num_scheduled_tokens)
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
enable_dbo)
extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
if self.torchair_graph_enabled and not with_prefill:
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
@@ -1739,8 +1766,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, num_tokens, with_prefill, False)
_) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
num_tokens, with_prefill,
False)
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively