[main] remove dbo code (#3712)

### What this PR does / why we need it?
Remove codes of dbo.
Currently, vLLM has supported dbo with pr:
https://github.com/vllm-project/vllm/pull/23693.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-10-25 15:53:01 +08:00
committed by GitHub
parent d9cdc65854
commit e5676fc36e
26 changed files with 69 additions and 1588 deletions

View File

@@ -36,9 +36,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import get_graph_params
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
@@ -184,7 +181,6 @@ class AscendMLAMetadata:
decode: Optional[AscendMLADecodeMetadata] = None
prefill: Optional[AscendMLAPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
@@ -195,17 +191,6 @@ class AscendMLAMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendMLAMetadata"]:
"""Split metadata for multi-stream with AscendMLAMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLAMetadata,
)
M = TypeVar("M", bound=AscendMLAMetadata)
@@ -538,7 +523,6 @@ class AscendMLAMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)
def build_for_graph_capture(
@@ -1158,14 +1142,8 @@ class AscendMLAImpl(MLAAttentionImpl):
else:
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope, k_nope, k_nope, **common_kwargs)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj(attn_output)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
return self._v_up_proj(attn_output)
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
bsz = attn_metadata.num_decode_tokens
@@ -1423,13 +1401,8 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
kv_cache[0].shape[1], attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[:num_decode_tokens] = output_decode
o_proj_input[:num_decode_tokens] = output_decode
if prefill_preprocess_res is not None:
# FIX: aicore move should be also placed on the comm stream in dbo,
@@ -1445,36 +1418,19 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[
num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
if current_ms_metadata is None:
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
current_ms_metadata.after_comm_event.record()
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
del o_proj_input
has_prefill = attn_metadata.num_prefills > 0
@@ -1719,18 +1675,9 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
seq_mask_pcp[:, i])
attn_output = attn_out_g
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj(attn_output)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
# TODO use update op to replace this
return self._v_up_proj(attn_output)
# TODO use update op to replace this
def _update_out_and_lse(
self,
out: torch.Tensor,