[main] remove dbo code (#3712)
### What this PR does / why we need it?
Remove the DBO (dual batch overlap) code from vllm-ascend. vLLM itself now supports DBO via https://github.com/vllm-project/vllm/pull/23693, so the Ascend-specific implementation is redundant.
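For background, DBO splits each batch into two micro-batches so that the communication phase of one can overlap the computation phase of the other. Below is a minimal conceptual sketch of that idea; the `compute`/`communicate` callables are hypothetical stand-ins, and CUDA streams are used for illustration (the Ascend backend would use the `torch.npu` equivalents). This is not the vLLM implementation.

```python
import torch

def run_with_dbo(batch: torch.Tensor, compute, communicate) -> torch.Tensor:
    """Conceptual DBO sketch: overlap micro-batch A's communication
    with micro-batch B's computation. Illustrative only."""
    half = batch.shape[0] // 2
    micro_a, micro_b = batch[:half], batch[half:]

    main_stream = torch.cuda.current_stream()
    comm_stream = torch.cuda.Stream()

    out_a = compute(micro_a)              # A: compute on the main stream
    comm_stream.wait_stream(main_stream)  # comm stream waits for A's compute
    with torch.cuda.stream(comm_stream):
        out_a = communicate(out_a)        # A: communicate on the side stream
    out_b = compute(micro_b)              # B: compute overlaps A's communication
    main_stream.wait_stream(comm_stream)  # rejoin before using out_a
    out_b = communicate(out_b)            # B: communicate on the main stream
    return torch.cat([out_a, out_b], dim=0)
```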
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993
Signed-off-by: zzzzwwjj <1183291235@qq.com>
```diff
@@ -264,8 +264,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
-            attn_state=attn_state,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
+            attn_state=attn_state)
         return attn_metadata

```
```diff
@@ -20,9 +20,6 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.context import get_multistream_comm_context
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
                                         npu_stream_switch, npu_wait_tensor)
```
```diff
@@ -141,7 +138,6 @@ class AscendMLATorchairMetadata:

     decode: Optional[AscendMLATorchairDecodeMetadata] = None
     prefill: Optional[AscendMLATorchairPrefillMetadata] = None
-    enable_dbo_across_dp: bool = False

     def __post_init__(self):
         pass
```
```diff
@@ -152,17 +148,6 @@ class AscendMLATorchairMetadata:
         #     f"Only {supported_head_sizes} are supported for head_dim,",
         #     f"received {self.head_dim}.")

-    def split_metadata_for_multistream(
-        self,
-        ms_split_config: MSAttentionMetadataSplitConfig,
-    ) -> list["AscendMLATorchairMetadata"]:
-        """Split metadata for multi-stream with AscendMLATorchairMetadata"""
-        return model_input_split_v1_mla_attn(
-            ms_split_config=ms_split_config,
-            attn_metadata=self,
-            _metadata_cls=AscendMLATorchairMetadata,
-        )
-

 M = TypeVar("M", bound=AscendMLATorchairMetadata)

```
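`split_metadata_for_multistream` (here and in the SFA variant further down) existed solely so DBO could carve one batch's attention metadata into per-micro-batch pieces. In spirit, the split slices every per-request field at a request boundary; a toy sketch with a hypothetical two-field metadata class, not the real `model_input_split_v1_mla_attn`:

```python
from dataclasses import dataclass
import torch

@dataclass
class ToyMetadata:
    # Hypothetical stand-in for the real metadata's per-request fields.
    seq_lens: torch.Tensor      # [num_reqs]
    block_tables: torch.Tensor  # [num_reqs, max_blocks]

def split_for_multistream(meta: ToyMetadata, split_at: int) -> list[ToyMetadata]:
    """Cut every per-request field at a request boundary, yielding one
    metadata object per micro-batch (the removed method's intent, in spirit)."""
    return [
        ToyMetadata(meta.seq_lens[:split_at], meta.block_tables[:split_at]),
        ToyMetadata(meta.seq_lens[split_at:], meta.block_tables[split_at:]),
    ]
```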
```diff
@@ -576,7 +561,6 @@ class AscendMLATorchairMetadataBuilder:
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )

     def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
```
```diff
@@ -1072,15 +1056,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
             context_lens=attn_metadata.decode.seq_lens,  # type:ignore
             mla_vheadsize=self.kv_lora_rank,
             out=attn_output)
-        current_ms_metadata = get_multistream_comm_context()
-        if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output,
-                                              enable_multistream_mla)
-        else:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                return self._v_up_proj_and_o_proj(attn_output)
+        return self._v_up_proj_and_o_proj(attn_output, enable_multistream_mla)

     def forward(
         self,
```
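The deleted branch is the standard event-based stream handoff: record an event on the main stream, make the communication stream wait on it, then run the projection there. A minimal sketch of that pattern, shown with `torch.cuda` streams for illustration (the Ascend backend uses the `torch.npu` equivalents):

```python
import torch

def handoff_to_comm_stream(fn, comm_stream: torch.cuda.Stream):
    """Run fn on comm_stream only after the main stream's pending work:
    a sketch of the before_comm_event record/wait pattern removed above."""
    before_comm_event = torch.cuda.Event()
    before_comm_event.record()        # mark progress on the current (main) stream
    with torch.cuda.stream(comm_stream):
        before_comm_event.wait()      # comm stream blocks until the mark is reached
        return fn()                   # e.g. the v_up / o_proj projection
```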
```diff
@@ -1248,14 +1225,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                                                        prefill_k_c_normed,
                                                        prefill_k_pe, kv_cache,
                                                        attn_metadata)
-                current_ms_metadata = get_multistream_comm_context()
-                if current_ms_metadata is not None:
-                    current_ms_metadata.before_comm_event.record()
-                    with torch.npu.stream(current_ms_metadata.comm_stream):
-                        current_ms_metadata.before_comm_event.wait()
-                        o_proj_input[num_decode_tokens:] = output_prefill
-                else:
-                    o_proj_input[num_decode_tokens:] = output_prefill
+                o_proj_input[num_decode_tokens:] = output_prefill

         if has_decode:
             if self.running_in_graph:
```
```diff
@@ -1269,35 +1239,19 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                                                       decode_k_nope,
                                                       decode_k_pe, kv_cache,
                                                       attn_metadata)
-                current_ms_metadata = get_multistream_comm_context()
-                if current_ms_metadata is not None:
-                    with torch.npu.stream(current_ms_metadata.comm_stream):
-                        o_proj_input[:num_decode_tokens] = output_decode
-                else:
-                    o_proj_input[:num_decode_tokens] = output_decode
+                o_proj_input[:num_decode_tokens] = output_decode

-        current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
-        if current_ms_metadata is None:
-            maybe_npu_prefetch(self.o_proj.weight,
-                               o_proj_input,
-                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                               enabled=enable_multistream_mla)
-
-            output[...] = self.o_proj(
-                o_proj_input,
-                is_prefill=True,
-                is_force_scatter=self.enable_shared_expert_dp)[0]
-        else:
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                maybe_npu_prefetch(self.o_proj.weight,
-                                   o_proj_input,
-                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                                   enabled=enable_multistream_mla)
-                output[...] = self.o_proj(
-                    o_proj_input,
-                    is_prefill=True,
-                    is_force_scatter=self.enable_shared_expert_dp)[0]
-                current_ms_metadata.after_comm_event.record()
+        maybe_npu_prefetch(self.o_proj.weight,
+                           o_proj_input,
+                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                           enabled=enable_multistream_mla)
+
+        output[...] = self.o_proj(
+            o_proj_input,
+            is_prefill=True,
+            is_force_scatter=self.enable_shared_expert_dp)[0]

         del o_proj_input
         return output_padded

```
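With the multistream branch gone, the `o_proj` weight prefetch always runs on the main stream, capped at 16 MB. The `max_size` cap presumably bounds how much of the weight a prefetch touches; a sketch of that arithmetic (an assumption about `maybe_npu_prefetch`'s intent, not its actual code):

```python
import torch

MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16 MB, as in the diff

def capped_prefetch_bytes(weight: torch.Tensor,
                          max_size: int = MAX_O_PROJ_PREFETCH_SIZE) -> int:
    """Bytes a capped prefetch would touch: the whole weight if it fits
    under the cap, otherwise only the first max_size bytes."""
    nbytes = weight.numel() * weight.element_size()
    return min(nbytes, max_size)
```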
```diff
@@ -110,30 +110,28 @@ class NPUTorchairModelRunner(NPUModelRunner):
         self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size

     def _sync_metadata_across_dp(
-        self, num_tokens: int, with_prefill: bool, enable_dbo: bool
-    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+            self, num_tokens: int,
+            with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
         """Override from NPUModelRunner to pad num_tokens"""
         if self.enable_shared_expert_dp:
             # Padding is not required for shared_expert_dp cases in eager mode.
-            return num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill
         if self.dp_size == 1:
             if not with_prefill:
                 maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
                     num_tokens)
-                return maybe_padded_num_tokens, None, with_prefill, enable_dbo
-            return num_tokens, None, with_prefill, enable_dbo
+                return maybe_padded_num_tokens, None, with_prefill
+            return num_tokens, None, with_prefill

-        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+        num_tokens_across_dp = torch.zeros(self.dp_size + 1,
                                            dtype=torch.int32,
                                            device="npu")
         num_tokens_across_dp[self.dp_rank] = num_tokens
-        num_tokens_across_dp[-2] = int(with_prefill)
-        num_tokens_across_dp[-1] = int(not enable_dbo)
+        num_tokens_across_dp[-1] = int(with_prefill)
         dist.all_reduce(num_tokens_across_dp,
                         group=get_dp_group().device_group)
-        with_prefill = bool(num_tokens_across_dp[-2])
-        enable_dbo = not bool(num_tokens_across_dp[-1])
-        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        with_prefill = bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-1]

         if not with_prefill:
             max_num_token = num_tokens_across_dp.max().item()
```
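`_sync_metadata_across_dp` packs every rank's token count plus boolean flags into one tensor so a single `all_reduce` synchronizes everything; with the `enable_dbo` flag gone, the tensor shrinks from `dp_size + 2` to `dp_size + 1` slots. A single-process sketch of the packing scheme (the `dist.all_reduce` call is elided, so this shows only one rank's contribution):

```python
import torch

def pack_and_unpack(dp_size: int, dp_rank: int,
                    num_tokens: int, with_prefill: bool):
    """Sketch of the post-PR layout: dp_size token-count slots plus one
    with_prefill flag slot, merged across ranks by a single all_reduce."""
    t = torch.zeros(dp_size + 1, dtype=torch.int32)
    t[dp_rank] = num_tokens
    t[-1] = int(with_prefill)
    # dist.all_reduce(t) would sum contributions from every rank here;
    # summing 0/1 flags makes t[-1] mean "any rank has prefill".
    with_prefill = bool(t[-1])
    num_tokens_across_dp = t[:-1]
    return num_tokens_across_dp, with_prefill
```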
```diff
@@ -146,7 +144,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
         else:
             maybe_padded_num_tokens = num_tokens

-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill

     def _build_dummy_attn_metadata(
         self,

```
```diff
@@ -21,8 +21,6 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import is_enable_nz
 from vllm_ascend.worker.npu_input_batch import InputBatch
```
```diff
@@ -141,7 +139,6 @@ class AscendSFATorchairMetadata:

     decode: Optional[AscendSFATorchairDecodeMetadata] = None
     prefill: Optional[AscendSFATorchairPrefillMetadata] = None
-    enable_dbo_across_dp: bool = False
     is_prefill: bool = False
     is_decode: bool = False

```
```diff
@@ -154,17 +151,6 @@ class AscendSFATorchairMetadata:
         #     f"Only {supported_head_sizes} are supported for head_dim,",
         #     f"received {self.head_dim}.")

-    def split_metadata_for_multistream(
-        self,
-        ms_split_config: MSAttentionMetadataSplitConfig,
-    ) -> list["AscendSFATorchairMetadata"]:
-        """Split metadata for multi-stream with AscendSFATorchairMetadata"""
-        return model_input_split_v1_mla_attn(
-            ms_split_config=ms_split_config,
-            attn_metadata=self,
-            _metadata_cls=AscendSFATorchairMetadata,
-        )
-

 M = TypeVar("M", bound=AscendSFATorchairMetadata)

```
```diff
@@ -616,7 +602,6 @@ class AscendSFATorchairMetadataBuilder:
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
             is_prefill=is_prefill,
             is_decode=is_decode)
```