### What this PR does / why we need it?
Cherry pick #1291 from v0.9.1-dev, This pr implement the synchronization
of whether `dbo` is enabled across all dp ranks. specifically, it
performed allreduce op across multiple DP ranks, only when all the dp
rank is `enable_dbo`, it is enabled
Co-authored-by: shikang-hangzhou <459956190@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
- vLLM version: v0.10.0
- vLLM main:
2836dd73f1
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -150,6 +150,8 @@ class AscendMetadata:
|
||||
# (num_tokens,)
|
||||
slot_mapping: torch.Tensor = None
|
||||
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
|
||||
|
||||
@@ -160,7 +162,11 @@ class AscendAttentionMetadataBuilder:
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(self, num_reqs, num_actual_tokens, max_query_len):
|
||||
def build(self,
|
||||
num_reqs,
|
||||
num_actual_tokens,
|
||||
max_query_len,
|
||||
enable_dbo_across_dp: bool = False):
|
||||
|
||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||
)
|
||||
@@ -187,15 +193,17 @@ class AscendAttentionMetadataBuilder:
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state)
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
enable_dbo_across_dp=enable_dbo_across_dp)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
|
||||
@@ -140,6 +140,8 @@ class AscendTorchairMetadata:
|
||||
|
||||
decode: Optional[AscendDecodeMetadata] = None
|
||||
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
|
||||
class AscendAttentionTorchairMetadataBuilder:
|
||||
|
||||
@@ -220,7 +222,8 @@ class AscendAttentionTorchairMetadataBuilder:
|
||||
num_reqs,
|
||||
num_actual_tokens,
|
||||
max_query_len,
|
||||
graph_pad_size: int = -1):
|
||||
graph_pad_size: int = -1,
|
||||
enable_dbo_across_dp: bool = False):
|
||||
|
||||
device = self.runner.device
|
||||
|
||||
@@ -298,7 +301,8 @@ class AscendAttentionTorchairMetadataBuilder:
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state)
|
||||
attn_state=attn_state,
|
||||
enable_dbo_across_dp=enable_dbo_across_dp)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ class AscendMLAMetadata:
|
||||
|
||||
decode: Optional[AscendMLADecodeMetadata] = None
|
||||
prefill: Optional[AscendMLAPrefillMetadata] = None
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
@@ -370,6 +371,7 @@ class AscendMLAMetadataBuilder:
|
||||
max_query_len: int,
|
||||
graph_pad_size: int = -1,
|
||||
query_start_loc: torch.Tensor = None,
|
||||
enable_dbo_across_dp: bool = False,
|
||||
) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
@@ -536,6 +538,7 @@ class AscendMLAMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -75,7 +75,6 @@ from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
|
||||
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
|
||||
MultiStreamStepMetadata,
|
||||
make_multistream_metadata_ds)
|
||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.utils import dispose_tensor
|
||||
|
||||
@@ -872,24 +871,9 @@ class CustomDeepseekDBOModel(nn.Module):
|
||||
|
||||
def can_run_ms(self):
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# support mla attention and V1 engine at present
|
||||
if not self.use_mla:
|
||||
return False
|
||||
# enable prefill overlap
|
||||
if attn_metadata is None or attn_metadata.num_prefills == 0:
|
||||
return False
|
||||
else:
|
||||
[token_index, seq_index
|
||||
] = compute_split_seq_index(attn_metadata.query_lens,
|
||||
attn_metadata.attn_state,
|
||||
attn_metadata.num_decode_tokens)
|
||||
if token_index == 0 or seq_index == 0 or seq_index == len(
|
||||
attn_metadata.query_lens):
|
||||
return False
|
||||
# check whether the total tokens exceed the threshold
|
||||
if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
|
||||
return False
|
||||
return True
|
||||
return not (attn_metadata is None or attn_metadata.num_prefills == 0
|
||||
or not attn_metadata.enable_dbo_across_dp)
|
||||
|
||||
def _forward_ms_layers(
|
||||
self,
|
||||
|
||||
@@ -96,10 +96,12 @@ def model_input_split_v1_mla_attn(
|
||||
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
|
||||
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
|
||||
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
|
||||
query_start_loc_post = deepcopy(
|
||||
attn_metadata.query_start_loc[seq_index:]
|
||||
) - attn_metadata.query_start_loc[seq_index]
|
||||
query_start_loc_pre = query_start_loc_post = None
|
||||
if attn_metadata.query_start_loc is not None:
|
||||
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
|
||||
query_start_loc_post = deepcopy(
|
||||
attn_metadata.query_start_loc[seq_index:]
|
||||
) - attn_metadata.query_start_loc[seq_index]
|
||||
[block_table_pre,
|
||||
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
|
||||
seq_index)
|
||||
@@ -223,6 +225,7 @@ def model_input_split_v1_mla_attn(
|
||||
attn_mask=attn_mask_pre,
|
||||
prefill=prefill_pre,
|
||||
decode=decode_pre,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
attention_metadata_post = _metadata_cls(
|
||||
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
||||
@@ -239,5 +242,6 @@ def model_input_split_v1_mla_attn(
|
||||
attn_state=attn_state_post,
|
||||
prefill=prefill_post,
|
||||
decode=decode_post,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
return [attention_metadata_pre, attention_metadata_post]
|
||||
|
||||
@@ -79,6 +79,7 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
@@ -606,6 +607,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
|
||||
forward_metadata[-1])
|
||||
|
||||
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
|
||||
attn_state: AscendAttentionState,
|
||||
num_tokens: int) -> bool:
|
||||
# do the checks for dp + dbo
|
||||
if attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
return False
|
||||
# considering the case that one dp rank may enable dbo while others may not
|
||||
if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
return False
|
||||
# TODO: remove it if token-level microbatch is enabled
|
||||
[token_index,
|
||||
seq_index] = compute_split_seq_index(query_lens, attn_state,
|
||||
num_tokens)
|
||||
if token_index == 0 or seq_index == 0 or seq_index == len(
|
||||
query_lens) or num_tokens < 256:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_eagle_atten_dict(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -1080,6 +1102,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
with_prefill = attn_state not in [
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||
attn_state,
|
||||
total_num_scheduled_tokens)
|
||||
|
||||
maybe_padded_num_tokens = total_num_scheduled_tokens
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
@@ -1087,7 +1112,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
total_num_scheduled_tokens)
|
||||
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
|
||||
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
|
||||
enable_dbo)
|
||||
extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
|
||||
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
|
||||
@@ -1739,8 +1766,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||
maybe_padded_num_tokens, num_tokens, with_prefill, False)
|
||||
_) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
|
||||
num_tokens, with_prefill,
|
||||
False)
|
||||
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
|
||||
Reference in New Issue
Block a user