[main] remove dbo code (#3712)
### What this PR does / why we need it?
Remove the vllm-ascend DBO (dual batch overlap) implementation.
vLLM itself now supports DBO via
https://github.com/vllm-project/vllm/pull/23693, so the copy carried in this repo is dead code.
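For context, the removed code overlapped compute with communication by re-dispatching work onto a dedicated comm stream, fenced by events (the `get_multistream_comm_context` branches deleted below). A minimal sketch of that pattern, assuming an Ascend build of PyTorch with `torch_npu` installed; the function and variable names are illustrative, not the actual vllm-ascend API:

```python
# Minimal sketch of the dual-stream handoff the removed dbo code used.
# Requires torch_npu (Ascend); names are illustrative.
import torch
import torch_npu  # noqa: F401  # registers the torch.npu namespace

comm_stream = torch.npu.Stream()
before_comm_event = torch.npu.Event()

def overlapped_epilogue(attn_output: torch.Tensor, v_up_proj):
    # Record where attention finished on the default stream ...
    before_comm_event.record()
    with torch.npu.stream(comm_stream):
        # ... and block the comm stream on that point, leaving the
        # default stream free to launch the other micro-batch's kernels.
        before_comm_event.wait()
        return v_up_proj(attn_output)
```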
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993
Signed-off-by: zzzzwwjj <1183291235@qq.com>
```diff
@@ -210,9 +210,6 @@ class AscendMetadata:
     # (num_tokens,)
     slot_mapping: torch.Tensor = None
 
     # *************************** Other Properties *************************** #
-    enable_dbo_across_dp: bool = False
-
     prefill: Optional[AscendMetadataForPrefill] = None
 
     decode_meta: Optional[AscendMetadataForDecode] = None
```
```diff
@@ -371,7 +368,6 @@ class AscendAttentionMetadataBuilder:
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
             prefill=prefill_metadata,
```
```diff
@@ -36,9 +36,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import get_graph_params
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.context import get_multistream_comm_context
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
```
```diff
@@ -184,7 +181,6 @@ class AscendMLAMetadata:
 
     decode: Optional[AscendMLADecodeMetadata] = None
     prefill: Optional[AscendMLAPrefillMetadata] = None
-    enable_dbo_across_dp: bool = False
 
     def __post_init__(self):
         pass
```
```diff
@@ -195,17 +191,6 @@ class AscendMLAMetadata:
     #     f"Only {supported_head_sizes} are supported for head_dim,",
     #     f"received {self.head_dim}.")
 
-    def split_metadata_for_multistream(
-        self,
-        ms_split_config: MSAttentionMetadataSplitConfig,
-    ) -> list["AscendMLAMetadata"]:
-        """Split metadata for multi-stream with AscendMLAMetadata"""
-        return model_input_split_v1_mla_attn(
-            ms_split_config=ms_split_config,
-            attn_metadata=self,
-            _metadata_cls=AscendMLAMetadata,
-        )
-
 
 M = TypeVar("M", bound=AscendMLAMetadata)
```
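For context, the deleted `split_metadata_for_multistream` helpers let dbo cut one step's attention metadata into two micro-batches, one per stream. A rough sketch of the idea only; field names other than `slot_mapping` are hypothetical, and the real `model_input_split_v1_mla_attn` also handled prefill/decode sub-structures:

```python
# Rough sketch of per-micro-batch metadata splitting (illustrative).
from dataclasses import dataclass, replace
import torch

@dataclass
class MetadataSketch:
    slot_mapping: torch.Tensor  # (num_tokens,)

def split_for_multistream(meta: MetadataSketch, split: int) -> list[MetadataSketch]:
    # Cut the token-indexed fields at a token boundary so each
    # micro-batch sees a self-consistent view of the batch.
    return [
        replace(meta, slot_mapping=meta.slot_mapping[:split]),
        replace(meta, slot_mapping=meta.slot_mapping[split:]),
    ]
```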
```diff
@@ -538,7 +523,6 @@ class AscendMLAMetadataBuilder:
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
 
     def build_for_graph_capture(
```
```diff
@@ -1158,14 +1142,8 @@ class AscendMLAImpl(MLAAttentionImpl):
         else:
             attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                 q_nope, k_nope, k_nope, **common_kwargs)
-        current_ms_metadata = get_multistream_comm_context()
-        if current_ms_metadata is None:
-            return self._v_up_proj(attn_output)
-        else:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                return self._v_up_proj(attn_output)
+        return self._v_up_proj(attn_output)
 
     def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
         bsz = attn_metadata.num_decode_tokens
```
```diff
@@ -1423,13 +1401,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                 decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
                 decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
                 kv_cache[0].shape[1], attn_metadata)
-            current_ms_metadata = get_multistream_comm_context()
-            if current_ms_metadata is not None:
-                with torch.npu.stream(current_ms_metadata.comm_stream):
-                    o_proj_input[:num_decode_tokens] = output_decode
-                    current_ms_metadata.after_comm_event.record()
-            else:
-                o_proj_input[:num_decode_tokens] = output_decode
+            o_proj_input[:num_decode_tokens] = output_decode
 
         if prefill_preprocess_res is not None:
             # FIX: aicore move should be also placed on the comm stream in dbo,
```
```diff
@@ -1445,36 +1418,19 @@ class AscendMLAImpl(MLAAttentionImpl):
                 prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
                 prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
                 prefill_preprocess_res.value, kv_cache, attn_metadata)
-            current_ms_metadata = get_multistream_comm_context()
-            if current_ms_metadata is not None:
-                with torch.npu.stream(current_ms_metadata.comm_stream):
-                    o_proj_input[num_decode_tokens:] = output_prefill
-                    current_ms_metadata.after_comm_event.record()
-            else:
-                o_proj_input[
-                    num_decode_tokens:num_actual_tokens] = output_prefill
-        # O proj
-        current_ms_metadata = get_multistream_comm_context()
-        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
-        if current_ms_metadata is None:
-            maybe_npu_prefetch(inputs=self.o_proj.weight,
-                               dependency=o_proj_input,
-                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                               enabled=self.enable_prefetch)
-
-            output[...] = self.o_proj(o_proj_input,
-                                      is_prefill=prefill_preprocess_res
-                                      is not None)[0]
-        else:
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                maybe_npu_prefetch(inputs=self.o_proj.weight,
-                                   dependency=o_proj_input,
-                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                                   enabled=self.enable_prefetch)
-                output[...] = self.o_proj(o_proj_input,
-                                          is_prefill=prefill_preprocess_res
-                                          is not None)[0]
-                current_ms_metadata.after_comm_event.record()
+            o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
+        # O proj
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
+        maybe_npu_prefetch(inputs=self.o_proj.weight,
+                           dependency=o_proj_input,
+                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                           enabled=self.enable_prefetch)
+
+        output[...] = self.o_proj(o_proj_input,
+                                  is_prefill=prefill_preprocess_res
+                                  is not None)[0]
+
         del o_proj_input
 
         has_prefill = attn_metadata.num_prefills > 0
```
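With the comm-stream branch gone, the output-projection epilogue above is a straight line: prefetch the `o_proj` weight (capped at 16 MiB) behind a dependency on the attention output, then project. A minimal sketch of that shape, reusing only the import, calls, and kwargs visible in the hunk; the wrapper itself is illustrative:

```python
# Minimal sketch of the post-removal epilogue (wrapper is illustrative;
# maybe_npu_prefetch and its kwargs are taken from the diff above).
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch

MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16 MiB prefetch cap

def run_o_proj(self, output, o_proj_input, has_prefill: bool):
    # Start pulling the projection weight toward the core caches while
    # o_proj_input is still settling, then run the matmul that needs it.
    maybe_npu_prefetch(inputs=self.o_proj.weight,
                       dependency=o_proj_input,
                       max_size=MAX_O_PROJ_PREFETCH_SIZE,
                       enabled=self.enable_prefetch)
    output[...] = self.o_proj(o_proj_input, is_prefill=has_prefill)[0]
```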
```diff
@@ -1719,18 +1675,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                 attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
                 seq_mask_pcp[:, i])
         attn_output = attn_out_g
-        current_ms_metadata = get_multistream_comm_context()
-        if current_ms_metadata is None:
-            return self._v_up_proj(attn_output)
-        else:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                return self._v_up_proj(attn_output)
-
-
-    # TODO use update op to replace this
+        return self._v_up_proj(attn_output)
+
     # TODO use update op to replace this
     def _update_out_and_lse(
         self,
         out: torch.Tensor,
```
```diff
@@ -17,11 +17,8 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
```
```diff
@@ -138,7 +135,6 @@ class AscendSFAMetadata:
 
     decode: Optional[AscendSFADecodeMetadata] = None
     prefill: Optional[AscendSFAPrefillMetadata] = None
-    enable_dbo_across_dp: bool = False
 
     def __post_init__(self):
         pass
```
```diff
@@ -149,17 +145,6 @@ class AscendSFAMetadata:
     #     f"Only {supported_head_sizes} are supported for head_dim,",
     #     f"received {self.head_dim}.")
 
-    def split_metadata_for_multistream(
-        self,
-        ms_split_config: MSAttentionMetadataSplitConfig,
-    ) -> list["AscendSFAMetadata"]:
-        """Split metadata for multi-stream with AscendSFAMetadata"""
-        return model_input_split_v1_mla_attn(
-            ms_split_config=ms_split_config,
-            attn_metadata=self,
-            _metadata_cls=AscendMLAMetadata,
-        )
-
 
 M = TypeVar("M", bound=AscendSFAMetadata)
```
```diff
@@ -434,7 +419,6 @@ class AscendSFAMetadataBuilder:
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
 
 
```
```diff
@@ -91,8 +91,6 @@ class AscendCommonAttentionMetadata:
 
     attn_state: Any = None
 
-    enable_dbo_across_dp: bool = False
-
     is_only_prefill: bool = False
 
     graph_pad_size: int = -1
```
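After this hunk, `enable_dbo_across_dp` no longer exists on the common metadata, so none of the builders above can forward it. An abbreviated sketch of what remains, keeping only the fields visible in the hunk (the real `AscendCommonAttentionMetadata` defines many more):

```python
# Abbreviated sketch: only the fields shown in the hunk above.
from dataclasses import dataclass
from typing import Any

@dataclass
class AscendCommonAttentionMetadataSketch:
    attn_state: Any = None
    is_only_prefill: bool = False
    graph_pad_size: int = -1
```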