[main] remove dbo code (#3712)

### What this PR does / why we need it?
Removes the DBO (dual batch overlap) code from vllm-ascend. vLLM now supports DBO upstream via
https://github.com/vllm-project/vllm/pull/23693, so the duplicate implementation here is no longer needed.
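
For context, DBO splits each batch into two micro-batches and overlaps one micro-batch's communication with the other's computation. Below is a minimal sketch of the idea, with `torch.cuda` streams as a portable stand-in for the `torch.npu` streams used on Ascend; `compute` and `communicate` are illustrative placeholders, not vLLM APIs:

```python
# Illustrative sketch only: vllm-ascend drives torch.npu streams; torch.cuda
# is the closest portable analog. `compute` / `communicate` are placeholders
# for an attention kernel and a collective, respectively.
import torch


def dbo_step(batch: torch.Tensor, compute, communicate):
    mb0, mb1 = batch.chunk(2)            # split into two micro-batches
    comm_stream = torch.cuda.Stream()
    done0 = torch.cuda.Event()

    out0 = compute(mb0)                  # mb0 compute on the default stream
    done0.record()                       # mark the point where out0 is valid

    with torch.cuda.stream(comm_stream):
        done0.wait()                     # comm stream waits for mb0 compute
        out0 = communicate(out0)         # mb0 communication overlaps ...

    out1 = compute(mb1)                  # ... with mb1 compute here

    torch.cuda.current_stream().wait_stream(comm_stream)
    out1 = communicate(out1)
    return torch.cat([out0, out1])
```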

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993

Signed-off-by: zzzzwwjj <1183291235@qq.com>
Author: zzzzwwjj
Date: 2025-10-25 15:53:01 +08:00
Committed by: GitHub
Parent: d9cdc65854
Commit: e5676fc36e
26 changed files with 69 additions and 1588 deletions


@@ -20,9 +20,6 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.context import get_multistream_comm_context
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
                                         npu_stream_switch, npu_wait_tensor)
@@ -141,7 +138,6 @@ class AscendMLATorchairMetadata:
     decode: Optional[AscendMLATorchairDecodeMetadata] = None
     prefill: Optional[AscendMLATorchairPrefillMetadata] = None
-    enable_dbo_across_dp: bool = False
 
     def __post_init__(self):
         pass
@@ -152,17 +148,6 @@ class AscendMLATorchairMetadata:
         #     f"Only {supported_head_sizes} are supported for head_dim,",
         #     f"received {self.head_dim}.")
 
-    def split_metadata_for_multistream(
-        self,
-        ms_split_config: MSAttentionMetadataSplitConfig,
-    ) -> list["AscendMLATorchairMetadata"]:
-        """Split metadata for multi-stream with AscendMLATorchairMetadata"""
-        return model_input_split_v1_mla_attn(
-            ms_split_config=ms_split_config,
-            attn_metadata=self,
-            _metadata_cls=AscendMLATorchairMetadata,
-        )
-
 
 M = TypeVar("M", bound=AscendMLATorchairMetadata)
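
The helper removed above delegated to `model_input_split_v1_mla_attn`, which slices one batch's attention metadata into two per-micro-batch views. A rough sketch of that kind of split, assuming a stripped-down metadata class with only `seq_lens` and `query_start_loc` (the real `AscendMLATorchairMetadata` carries many more fields):

```python
# Simplified sketch; field names follow the original metadata, but the
# split logic here is an illustration, not the vLLM implementation.
from dataclasses import dataclass

import torch


@dataclass
class ToyAttnMetadata:
    seq_lens: torch.Tensor         # [num_reqs]
    query_start_loc: torch.Tensor  # [num_reqs + 1], cumulative token offsets


def split_for_multistream(meta: ToyAttnMetadata, split_req: int):
    """Split one batch's metadata into two micro-batch views at `split_req`."""
    first = ToyAttnMetadata(
        seq_lens=meta.seq_lens[:split_req],
        query_start_loc=meta.query_start_loc[:split_req + 1],
    )
    base = meta.query_start_loc[split_req]
    second = ToyAttnMetadata(
        seq_lens=meta.seq_lens[split_req:],
        # re-base offsets so the second micro-batch starts at token 0
        query_start_loc=meta.query_start_loc[split_req:] - base,
    )
    return [first, second]
```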
@@ -576,7 +561,6 @@ class AscendMLATorchairMetadataBuilder:
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
 
     def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
@@ -1072,15 +1056,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                 context_lens=attn_metadata.decode.seq_lens,  # type:ignore
                 mla_vheadsize=self.kv_lora_rank,
                 out=attn_output)
-        current_ms_metadata = get_multistream_comm_context()
-        if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output,
-                                              enable_multistream_mla)
-        else:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                return self._v_up_proj_and_o_proj(attn_output)
+        return self._v_up_proj_and_o_proj(attn_output,
+                                          enable_multistream_mla)
 
     def forward(
         self,
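
The deleted else-branch above is the half of the DBO machinery that ran the output projection on a separate communication stream. A minimal sketch of that record/wait pattern, with `torch.cuda` standing in for `torch.npu` and `project_out` as a hypothetical stand-in for `_v_up_proj_and_o_proj`:

```python
# Sketch of the removed record/wait handshake; torch.cuda stands in for
# torch.npu, and `project_out` is a placeholder, not the vLLM API.
import torch


def overlapped_projection(attn_output: torch.Tensor, project_out):
    comm_stream = torch.cuda.Stream()
    before_comm_event = torch.cuda.Event()

    # attn_output was produced on the current (compute) stream;
    # mark the point at which it becomes valid.
    before_comm_event.record()

    with torch.cuda.stream(comm_stream):
        # The comm stream waits until attn_output is ready, then runs the
        # projection there, freeing the compute stream for the other
        # micro-batch.
        before_comm_event.wait()
        return project_out(attn_output)
```

With the comm context gone, the `+` lines show the projection simply running inline on the current stream.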
@@ -1248,14 +1225,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                                                     prefill_k_c_normed,
                                                     prefill_k_pe, kv_cache,
                                                     attn_metadata)
-                current_ms_metadata = get_multistream_comm_context()
-                if current_ms_metadata is not None:
-                    current_ms_metadata.before_comm_event.record()
-                    with torch.npu.stream(current_ms_metadata.comm_stream):
-                        current_ms_metadata.before_comm_event.wait()
-                        o_proj_input[num_decode_tokens:] = output_prefill
-                else:
-                    o_proj_input[num_decode_tokens:] = output_prefill
+                o_proj_input[num_decode_tokens:] = output_prefill
 
         if has_decode:
             if self.running_in_graph:
@@ -1269,35 +1239,19 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                                                        decode_k_nope,
                                                        decode_k_pe, kv_cache,
                                                        attn_metadata)
-                current_ms_metadata = get_multistream_comm_context()
-                if current_ms_metadata is not None:
-                    with torch.npu.stream(current_ms_metadata.comm_stream):
-                        o_proj_input[:num_decode_tokens] = output_decode
-                else:
-                    o_proj_input[:num_decode_tokens] = output_decode
+                o_proj_input[:num_decode_tokens] = output_decode
 
-        current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
-        if current_ms_metadata is None:
-            maybe_npu_prefetch(self.o_proj.weight,
-                               o_proj_input,
-                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                               enabled=enable_multistream_mla)
-            output[...] = self.o_proj(
-                o_proj_input,
-                is_prefill=True,
-                is_force_scatter=self.enable_shared_expert_dp)[0]
-        else:
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                maybe_npu_prefetch(self.o_proj.weight,
-                                   o_proj_input,
-                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                                   enabled=enable_multistream_mla)
-                output[...] = self.o_proj(
-                    o_proj_input,
-                    is_prefill=True,
-                    is_force_scatter=self.enable_shared_expert_dp)[0]
-                current_ms_metadata.after_comm_event.record()
+        maybe_npu_prefetch(self.o_proj.weight,
+                           o_proj_input,
+                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                           enabled=enable_multistream_mla)
+        output[...] = self.o_proj(
+            o_proj_input,
+            is_prefill=True,
+            is_force_scatter=self.enable_shared_expert_dp)[0]
         del o_proj_input
         return output_padded