[main] remove dbo code (#3712)
### What this PR does / why we need it?
Remove the dual batch overlap (DBO) code from vllm-ascend. vLLM itself now supports DBO via
https://github.com/vllm-project/vllm/pull/23693, so the Ascend-specific implementation is redundant.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993
Signed-off-by: zzzzwwjj <1183291235@qq.com>
```diff
@@ -121,7 +121,6 @@ from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
-from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.logits_processor import build_logitsprocs
```
```diff
@@ -859,8 +858,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )
 
     def _sync_metadata_across_dp(
-        self, num_tokens: int, with_prefill: bool, enable_dbo: bool
-    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+            self, num_tokens: int,
+            with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
         # TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
         # our case, we still need to sync the other two flags as well. So we need to
         # include them in the all_reduce operation, and more over, we CANNOT skip it
```
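The TODO above is why the removed code packed the per-rank token counts and both flags into one tensor and issued a single `all_reduce` over the DP group. Below is a minimal sketch of that pattern, with an illustrative helper name, CPU tensors in place of `device="npu"`, and an already-initialized `torch.distributed` group assumed. The point is how one SUM reduction yields OR semantics for `with_prefill` and AND semantics for `enable_dbo`:

```python
import torch
import torch.distributed as dist


def sync_metadata(num_tokens: int, with_prefill: bool, enable_dbo: bool,
                  dp_rank: int, dp_size: int, group=None):
    # One slot per rank: each rank writes its own count and zeros elsewhere,
    # so a SUM all_reduce recovers every rank's num_tokens.
    num_tokens_tensor = torch.tensor(
        [num_tokens if i == dp_rank else 0 for i in range(dp_size)],
        dtype=torch.int32)

    # Flag encoding makes SUM do the right thing:
    # - with_prefill is OR across ranks (sum > 0 iff any rank prefills),
    # - enable_dbo is AND across ranks, encoded as int(not enable_dbo):
    #   the sum stays 0 only if every rank wants DBO on.
    flags_tensor = torch.tensor([int(with_prefill), int(not enable_dbo)],
                                dtype=torch.int32)

    packed = torch.cat([num_tokens_tensor, flags_tensor])
    dist.all_reduce(packed, group=group)  # one collective instead of three

    max_tokens = int(packed[:-2].max().item())
    global_with_prefill = bool(packed[-2].item())
    global_enable_dbo = not bool(packed[-1].item())
    return max_tokens, global_with_prefill, global_enable_dbo
```

With DBO removed, only the prefill flag is left in the packing, which is exactly what the next hunks reduce the tensor layout to.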
```diff
@@ -868,31 +867,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here
         # immediately once the other two flags are no longer needed.
         if self.dp_size == 1:
-            return num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill
 
-        # Sync num_tokens, with_prefill, enable_dbo across dp ranks
+        # Sync num_tokens, with_prefill across dp ranks
         num_tokens_tensor = torch.tensor([
             num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
         ],
                                          dtype=torch.int32,
                                          device="npu")
 
-        flags_tensor = torch.tensor(
-            [int(with_prefill), int(not enable_dbo)],
-            dtype=torch.int32,
-            device="npu")
+        flags_tensor = torch.tensor([int(with_prefill)],
+                                    dtype=torch.int32,
+                                    device="npu")
 
         packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
 
         dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
 
         # Unpack the results
-        num_tokens_across_dp = packed_tensor[:-2]
-        synced_flags = packed_tensor[-2:]
+        num_tokens_across_dp = packed_tensor[:-1]
+        synced_flags = packed_tensor[-1:]
 
         max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
         global_with_prefill = bool(synced_flags[0])
-        global_enable_dbo = not bool(synced_flags[1])
 
         # Create a tensor for num_tokens_after_padding
         num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
@@ -900,28 +897,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                                  device="cpu",
                                                  dtype=torch.int32)
 
-        return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
-
-    def _check_dbo_is_valid(self, query_lens: torch.Tensor,
-                            attn_state: AscendAttentionState,
-                            num_tokens: int) -> bool:
-        # do the checks for dp + dbo
-        if attn_state in [
-                AscendAttentionState.DecodeOnly,
-                AscendAttentionState.SpecDecoding
-        ]:
-            return False
-        # considering the case that one dp rank may enable dbo while others may not
-        if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
-            return False
-        # TODO: remove it if token-level microbatch is enabled
-        [token_index,
-         seq_index] = compute_split_seq_index(query_lens, attn_state,
-                                              num_tokens)
-        if token_index == 0 or seq_index == 0 or seq_index == len(
-                query_lens) or num_tokens < 256:
-            return False
-        return True
+        return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill
 
     def get_model(self) -> nn.Module:
         # get raw model out of the aclgraph wrapper.
```
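The deleted `_check_dbo_is_valid` gated DBO on three things: an attention state that actually contains prefill work, MLA plus the `VLLM_ASCEND_ENABLE_DBO` env flag, and a usable split point where neither microbatch ends up empty. As a rough illustration of that last check, here is a toy version; `compute_split_point` is a hypothetical stand-in for `compute_split_seq_index` (whose import this PR also removes), assumed to cut the batch at the sequence boundary nearest half the scheduled tokens:

```python
def compute_split_point(query_lens: list[int],
                        num_tokens: int) -> tuple[int, int]:
    # Hypothetical stand-in for compute_split_seq_index: pick the sequence
    # boundary whose left-hand token count lands closest to half the batch.
    target = num_tokens // 2
    best_tokens, best_seq, acc = 0, 0, 0
    for seq_index, q_len in enumerate(query_lens, start=1):
        acc += q_len
        if abs(acc - target) < abs(best_tokens - target):
            best_tokens, best_seq = acc, seq_index
    return best_tokens, best_seq


def dbo_split_is_usable(query_lens: list[int], num_tokens: int) -> bool:
    # Overlapping two microbatches only pays off with enough work per half;
    # the removed code used a 256-token floor.
    if num_tokens < 256:
        return False
    token_index, seq_index = compute_split_point(query_lens, num_tokens)
    # A cut at either end would leave one microbatch empty.
    return not (token_index == 0 or seq_index == 0
                or seq_index == len(query_lens))
```

Decode-only and spec-decoding batches were rejected before this point in the removed function, since they carry no prefill work worth overlapping.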
```diff
@@ -1430,16 +1406,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         ]
 
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
-        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
-                                              attn_state,
-                                              total_num_scheduled_tokens)
 
         # Get info across DP ranks.
         # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
         # Otherwise, it's just max_tokens_across_dp_cpu
-        (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
-                                                     with_prefill, enable_dbo)
+        (maybe_padded_num_tokens, num_tokens_across_dp,
+         with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
+                                                       with_prefill)
 
         # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
         # We should consider removing maybe_padded_num_tokens later
```
```diff
@@ -1707,7 +1680,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            enable_dbo_across_dp=enable_dbo,
             is_only_prefill=bool(np.all(num_valid_tokens != 1)),
             max_query_len=max_num_scheduled_tokens,
             graph_pad_size=self.graph_pad_size,
```
```diff
@@ -2603,8 +2575,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_tokens = math.ceil(num_tokens / tp_size) * tp_size
 
         # Padding for DP
-        (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
+        (num_tokens, num_tokens_across_dp,
+         with_prefill) = self._sync_metadata_across_dp(num_tokens,
+                                                       with_prefill)
 
         moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
 
```