diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 668c802..6f7473f 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -150,6 +150,8 @@ class AscendMetadata: # (num_tokens,) slot_mapping: torch.Tensor = None + enable_dbo_across_dp: bool = False + class AscendAttentionMetadataBuilder: @@ -160,7 +162,11 @@ class AscendAttentionMetadataBuilder: scheduler_output: "SchedulerOutput") -> bool: return False - def build(self, num_reqs, num_actual_tokens, max_query_len): + def build(self, + num_reqs, + num_actual_tokens, + max_query_len, + enable_dbo_across_dp: bool = False): block_table = self.runner.input_batch.block_table[0].get_device_tensor( ) @@ -187,15 +193,17 @@ class AscendAttentionMetadataBuilder: attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ) - attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens, - block_tables=block_table, - query_start_loc=query_start_loc, - query_lens=query_lens, - seq_lens=seq_lens, - max_query_len=max_query_len, - slot_mapping=slot_mapping, - attn_mask=attn_mask, - attn_state=attn_state) + attn_metadata = AscendMetadata( + num_actual_tokens=num_actual_tokens, + block_tables=block_table, + query_start_loc=query_start_loc, + query_lens=query_lens, + seq_lens=seq_lens, + max_query_len=max_query_len, + slot_mapping=slot_mapping, + attn_mask=attn_mask, + attn_state=attn_state, + enable_dbo_across_dp=enable_dbo_across_dp) return attn_metadata diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index 48437b4..fe7eb9d 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -140,6 +140,8 @@ class AscendTorchairMetadata: decode: Optional[AscendDecodeMetadata] = None + enable_dbo_across_dp: bool = False + class AscendAttentionTorchairMetadataBuilder: @@ -220,7 +222,8 @@ class AscendAttentionTorchairMetadataBuilder: num_reqs, num_actual_tokens, max_query_len, - graph_pad_size: int = -1): + graph_pad_size: int = -1, + enable_dbo_across_dp: bool = False): device = self.runner.device @@ -298,7 +301,8 @@ class AscendAttentionTorchairMetadataBuilder: max_query_len=max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, - attn_state=attn_state) + attn_state=attn_state, + enable_dbo_across_dp=enable_dbo_across_dp) return attn_metadata diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 7771632..b2b3ad0 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -137,6 +137,7 @@ class AscendMLAMetadata: decode: Optional[AscendMLADecodeMetadata] = None prefill: Optional[AscendMLAPrefillMetadata] = None + enable_dbo_across_dp: bool = False def __post_init__(self): pass @@ -370,6 +371,7 @@ class AscendMLAMetadataBuilder: max_query_len: int, graph_pad_size: int = -1, query_start_loc: torch.Tensor = None, + enable_dbo_across_dp: bool = False, ) -> AscendMLAMetadata: assert self._num_decodes + self._num_prefills == num_reqs @@ -536,6 +538,7 @@ class AscendMLAMetadataBuilder: query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, + enable_dbo_across_dp=enable_dbo_across_dp, ) diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 13e5efa..9469e99 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -75,7 +75,6 @@ from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.utils import dispose_tensor @@ -872,24 +871,9 @@ class CustomDeepseekDBOModel(nn.Module): def can_run_ms(self): attn_metadata = get_forward_context().attn_metadata - # support mla attention and V1 engine at present - if not self.use_mla: - return False # enable prefill overlap - if attn_metadata is None or attn_metadata.num_prefills == 0: - return False - else: - [token_index, seq_index - ] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, - attn_metadata.num_decode_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len( - attn_metadata.query_lens): - return False - # check whether the total tokens exceed the threshold - if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: - return False - return True + return not (attn_metadata is None or attn_metadata.num_prefills == 0 + or not attn_metadata.enable_dbo_across_dp) def _forward_ms_layers( self, diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 3af6337..fd32a18 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -96,10 +96,12 @@ def model_input_split_v1_mla_attn( seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) - query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] - query_start_loc_post = deepcopy( - attn_metadata.query_start_loc[seq_index:] - ) - attn_metadata.query_start_loc[seq_index] + query_start_loc_pre = query_start_loc_post = None + if attn_metadata.query_start_loc is not None: + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] [block_table_pre, block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, seq_index) @@ -223,6 +225,7 @@ def model_input_split_v1_mla_attn( attn_mask=attn_mask_pre, prefill=prefill_pre, decode=decode_pre, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) attention_metadata_post = _metadata_cls( num_actual_tokens=attn_metadata.num_actual_tokens - token_index, @@ -239,5 +242,6 @@ def model_input_split_v1_mla_attn( attn_state=attn_state_post, prefill=prefill_post, decode=decode_post, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) return [attention_metadata_pre, attention_metadata_post] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 9782e17..55bb562 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -79,6 +79,7 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState, AscendMetadata) from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata from vllm_ascend.attention.mla_v1 import AscendMLAMetadata +from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.torchair.utils import (check_torchair_cache_exist, @@ -606,6 +607,27 @@ class NPUModelRunner(LoRAModelRunnerMixin): return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool( forward_metadata[-1]) + def _check_dbo_is_valid(self, query_lens: torch.Tensor, + attn_state: AscendAttentionState, + num_tokens: int) -> bool: + # do the checks for dp + dbo + if attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + return False + # considering the case that one dp rank may enable dbo while others may not + if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO: + return False + # TODO: remove it if token-level microbatch is enabled + [token_index, + seq_index] = compute_split_seq_index(query_lens, attn_state, + num_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + query_lens) or num_tokens < 256: + return False + return True + def get_eagle_atten_dict( self, scheduler_output: "SchedulerOutput", @@ -1080,6 +1102,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): with_prefill = attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), + attn_state, + total_num_scheduled_tokens) maybe_padded_num_tokens = total_num_scheduled_tokens if self.torchair_graph_enabled and not with_prefill: @@ -1087,7 +1112,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): total_num_scheduled_tokens) (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, enable_dbo) = self._get_forward_metadata_across_dp( - maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill) + maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill, + enable_dbo) + extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo if self.torchair_graph_enabled and not with_prefill: graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens @@ -1739,8 +1766,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Padding for DP (num_tokens, num_tokens_across_dp, with_prefill, - enable_dbo) = self._get_forward_metadata_across_dp( - maybe_padded_num_tokens, num_tokens, with_prefill, False) + _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens, + num_tokens, with_prefill, + False) # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively