### What this PR does / why we need it?
Cherry-pick #1291 from v0.9.1-dev. This PR implements synchronization
of whether `dbo` is enabled across all DP ranks. Specifically, it
performs an all-reduce op across the DP ranks, and `dbo` is enabled
only when every DP rank has `enable_dbo` set.
Co-authored-by: shikang-hangzhou <459956190@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
- vLLM version: v0.10.0
- vLLM main:
2836dd73f1
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -150,6 +150,8 @@ class AscendMetadata:
|
|||||||
# (num_tokens,)
|
# (num_tokens,)
|
||||||
slot_mapping: torch.Tensor = None
|
slot_mapping: torch.Tensor = None
|
||||||
|
|
||||||
|
enable_dbo_across_dp: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionMetadataBuilder:
|
class AscendAttentionMetadataBuilder:
|
||||||
|
|
||||||
@@ -160,7 +162,11 @@ class AscendAttentionMetadataBuilder:
|
|||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def build(self, num_reqs, num_actual_tokens, max_query_len):
|
def build(self,
|
||||||
|
num_reqs,
|
||||||
|
num_actual_tokens,
|
||||||
|
max_query_len,
|
||||||
|
enable_dbo_across_dp: bool = False):
|
||||||
|
|
||||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||||
)
|
)
|
||||||
@@ -187,15 +193,17 @@ class AscendAttentionMetadataBuilder:
|
|||||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
|
attn_metadata = AscendMetadata(
|
||||||
block_tables=block_table,
|
num_actual_tokens=num_actual_tokens,
|
||||||
query_start_loc=query_start_loc,
|
block_tables=block_table,
|
||||||
query_lens=query_lens,
|
query_start_loc=query_start_loc,
|
||||||
seq_lens=seq_lens,
|
query_lens=query_lens,
|
||||||
max_query_len=max_query_len,
|
seq_lens=seq_lens,
|
||||||
slot_mapping=slot_mapping,
|
max_query_len=max_query_len,
|
||||||
attn_mask=attn_mask,
|
slot_mapping=slot_mapping,
|
||||||
attn_state=attn_state)
|
attn_mask=attn_mask,
|
||||||
|
attn_state=attn_state,
|
||||||
|
enable_dbo_across_dp=enable_dbo_across_dp)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -140,6 +140,8 @@ class AscendTorchairMetadata:
|
|||||||
|
|
||||||
decode: Optional[AscendDecodeMetadata] = None
|
decode: Optional[AscendDecodeMetadata] = None
|
||||||
|
|
||||||
|
enable_dbo_across_dp: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionTorchairMetadataBuilder:
|
class AscendAttentionTorchairMetadataBuilder:
|
||||||
|
|
||||||
@@ -220,7 +222,8 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
num_reqs,
|
num_reqs,
|
||||||
num_actual_tokens,
|
num_actual_tokens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
graph_pad_size: int = -1):
|
graph_pad_size: int = -1,
|
||||||
|
enable_dbo_across_dp: bool = False):
|
||||||
|
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
|
|
||||||
@@ -298,7 +301,8 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
attn_state=attn_state)
|
attn_state=attn_state,
|
||||||
|
enable_dbo_across_dp=enable_dbo_across_dp)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ class AscendMLAMetadata:
|
|||||||
|
|
||||||
decode: Optional[AscendMLADecodeMetadata] = None
|
decode: Optional[AscendMLADecodeMetadata] = None
|
||||||
prefill: Optional[AscendMLAPrefillMetadata] = None
|
prefill: Optional[AscendMLAPrefillMetadata] = None
|
||||||
|
enable_dbo_across_dp: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
pass
|
pass
|
||||||
@@ -370,6 +371,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
max_query_len: int,
|
max_query_len: int,
|
||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1,
|
||||||
query_start_loc: torch.Tensor = None,
|
query_start_loc: torch.Tensor = None,
|
||||||
|
enable_dbo_across_dp: bool = False,
|
||||||
) -> AscendMLAMetadata:
|
) -> AscendMLAMetadata:
|
||||||
assert self._num_decodes + self._num_prefills == num_reqs
|
assert self._num_decodes + self._num_prefills == num_reqs
|
||||||
|
|
||||||
@@ -536,6 +538,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
|
enable_dbo_across_dp=enable_dbo_across_dp,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
|
|||||||
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
|
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
|
||||||
MultiStreamStepMetadata,
|
MultiStreamStepMetadata,
|
||||||
make_multistream_metadata_ds)
|
make_multistream_metadata_ds)
|
||||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
from vllm_ascend.utils import dispose_tensor
|
from vllm_ascend.utils import dispose_tensor
|
||||||
|
|
||||||
@@ -872,24 +871,9 @@ class CustomDeepseekDBOModel(nn.Module):
|
|||||||
|
|
||||||
def can_run_ms(self):
|
def can_run_ms(self):
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
# support mla attention and V1 engine at present
|
|
||||||
if not self.use_mla:
|
|
||||||
return False
|
|
||||||
# enable prefill overlap
|
# enable prefill overlap
|
||||||
if attn_metadata is None or attn_metadata.num_prefills == 0:
|
return not (attn_metadata is None or attn_metadata.num_prefills == 0
|
||||||
return False
|
or not attn_metadata.enable_dbo_across_dp)
|
||||||
else:
|
|
||||||
[token_index, seq_index
|
|
||||||
] = compute_split_seq_index(attn_metadata.query_lens,
|
|
||||||
attn_metadata.attn_state,
|
|
||||||
attn_metadata.num_decode_tokens)
|
|
||||||
if token_index == 0 or seq_index == 0 or seq_index == len(
|
|
||||||
attn_metadata.query_lens):
|
|
||||||
return False
|
|
||||||
# check whether the total tokens exceed the threshold
|
|
||||||
if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _forward_ms_layers(
|
def _forward_ms_layers(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -96,10 +96,12 @@ def model_input_split_v1_mla_attn(
|
|||||||
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
|
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
|
||||||
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||||
|
|
||||||
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
|
query_start_loc_pre = query_start_loc_post = None
|
||||||
query_start_loc_post = deepcopy(
|
if attn_metadata.query_start_loc is not None:
|
||||||
attn_metadata.query_start_loc[seq_index:]
|
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
|
||||||
) - attn_metadata.query_start_loc[seq_index]
|
query_start_loc_post = deepcopy(
|
||||||
|
attn_metadata.query_start_loc[seq_index:]
|
||||||
|
) - attn_metadata.query_start_loc[seq_index]
|
||||||
[block_table_pre,
|
[block_table_pre,
|
||||||
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
|
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
|
||||||
seq_index)
|
seq_index)
|
||||||
@@ -223,6 +225,7 @@ def model_input_split_v1_mla_attn(
|
|||||||
attn_mask=attn_mask_pre,
|
attn_mask=attn_mask_pre,
|
||||||
prefill=prefill_pre,
|
prefill=prefill_pre,
|
||||||
decode=decode_pre,
|
decode=decode_pre,
|
||||||
|
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||||
)
|
)
|
||||||
attention_metadata_post = _metadata_cls(
|
attention_metadata_post = _metadata_cls(
|
||||||
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
||||||
@@ -239,5 +242,6 @@ def model_input_split_v1_mla_attn(
|
|||||||
attn_state=attn_state_post,
|
attn_state=attn_state_post,
|
||||||
prefill=prefill_post,
|
prefill=prefill_post,
|
||||||
decode=decode_post,
|
decode=decode_post,
|
||||||
|
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||||
)
|
)
|
||||||
return [attention_metadata_pre, attention_metadata_post]
|
return [attention_metadata_pre, attention_metadata_post]
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
|||||||
AscendMetadata)
|
AscendMetadata)
|
||||||
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
||||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||||
|
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||||
@@ -606,6 +607,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
|
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
|
||||||
forward_metadata[-1])
|
forward_metadata[-1])
|
||||||
|
|
||||||
|
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
|
||||||
|
attn_state: AscendAttentionState,
|
||||||
|
num_tokens: int) -> bool:
|
||||||
|
# do the checks for dp + dbo
|
||||||
|
if attn_state in [
|
||||||
|
AscendAttentionState.DecodeOnly,
|
||||||
|
AscendAttentionState.SpecDecoding
|
||||||
|
]:
|
||||||
|
return False
|
||||||
|
# considering the case that one dp rank may enable dbo while others may not
|
||||||
|
if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||||
|
return False
|
||||||
|
# TODO: remove it if token-level microbatch is enabled
|
||||||
|
[token_index,
|
||||||
|
seq_index] = compute_split_seq_index(query_lens, attn_state,
|
||||||
|
num_tokens)
|
||||||
|
if token_index == 0 or seq_index == 0 or seq_index == len(
|
||||||
|
query_lens) or num_tokens < 256:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def get_eagle_atten_dict(
|
def get_eagle_atten_dict(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@@ -1080,6 +1102,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
with_prefill = attn_state not in [
|
with_prefill = attn_state not in [
|
||||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||||
]
|
]
|
||||||
|
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||||
|
attn_state,
|
||||||
|
total_num_scheduled_tokens)
|
||||||
|
|
||||||
maybe_padded_num_tokens = total_num_scheduled_tokens
|
maybe_padded_num_tokens = total_num_scheduled_tokens
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
@@ -1087,7 +1112,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
total_num_scheduled_tokens)
|
total_num_scheduled_tokens)
|
||||||
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
|
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
|
||||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||||
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
|
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
|
||||||
|
enable_dbo)
|
||||||
|
extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
|
||||||
|
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
|
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
|
||||||
@@ -1739,8 +1766,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Padding for DP
|
# Padding for DP
|
||||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
_) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
|
||||||
maybe_padded_num_tokens, num_tokens, with_prefill, False)
|
num_tokens, with_prefill,
|
||||||
|
False)
|
||||||
|
|
||||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||||
# for dummy run with LoRA so that the num_reqs collectively
|
# for dummy run with LoRA so that the num_reqs collectively
|
||||||
|
|||||||
Reference in New Issue
Block a user