[main] remove dbo code (#3712)
### What this PR does / why we need it?
Remove the DBO (dual-batch overlap) code from this repository.
vLLM now supports DBO upstream via this PR:
https://github.com/vllm-project/vllm/pull/23693.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -36,9 +36,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
@@ -184,7 +181,6 @@ class AscendMLAMetadata:
|
||||
|
||||
decode: Optional[AscendMLADecodeMetadata] = None
|
||||
prefill: Optional[AscendMLAPrefillMetadata] = None
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
@@ -195,17 +191,6 @@ class AscendMLAMetadata:
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
def split_metadata_for_multistream(
|
||||
self,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> list["AscendMLAMetadata"]:
|
||||
"""Split metadata for multi-stream with AscendMLAMetadata"""
|
||||
return model_input_split_v1_mla_attn(
|
||||
ms_split_config=ms_split_config,
|
||||
attn_metadata=self,
|
||||
_metadata_cls=AscendMLAMetadata,
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
|
||||
@@ -538,7 +523,6 @@ class AscendMLAMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
def build_for_graph_capture(
|
||||
@@ -1158,14 +1142,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
else:
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
return self._v_up_proj(attn_output)
|
||||
else:
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
|
||||
bsz = attn_metadata.num_decode_tokens
|
||||
@@ -1423,13 +1401,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
|
||||
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
|
||||
kv_cache[0].shape[1], attn_metadata)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is not None:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
else:
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
|
||||
if prefill_preprocess_res is not None:
|
||||
# FIX: aicore move should be also placed on the comm stream in dbo,
|
||||
@@ -1445,36 +1418,19 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
|
||||
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
|
||||
prefill_preprocess_res.value, kv_cache, attn_metadata)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is not None:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
o_proj_input[num_decode_tokens:] = output_prefill
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
else:
|
||||
o_proj_input[
|
||||
num_decode_tokens:num_actual_tokens] = output_prefill
|
||||
# O proj
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||
if current_ms_metadata is None:
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
dependency=o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
|
||||
output[...] = self.o_proj(o_proj_input,
|
||||
is_prefill=prefill_preprocess_res
|
||||
is not None)[0]
|
||||
else:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
dependency=o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
output[...] = self.o_proj(o_proj_input,
|
||||
is_prefill=prefill_preprocess_res
|
||||
is not None)[0]
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
|
||||
# O proj
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
dependency=o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
|
||||
output[...] = self.o_proj(o_proj_input,
|
||||
is_prefill=prefill_preprocess_res
|
||||
is not None)[0]
|
||||
|
||||
del o_proj_input
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
@@ -1719,18 +1675,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
|
||||
seq_mask_pcp[:, i])
|
||||
attn_output = attn_out_g
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
return self._v_up_proj(attn_output)
|
||||
else:
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
|
||||
# TODO use update op to replace this
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
# TODO use update op to replace this
|
||||
def _update_out_and_lse(
|
||||
self,
|
||||
out: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user