[main] remove dbo code (#3712)

### What this PR does / why we need it?
Remove the DBO code from this repository.
vLLM itself now supports DBO via PR:
https://github.com/vllm-project/vllm/pull/23693.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-10-25 15:53:01 +08:00
committed by GitHub
parent d9cdc65854
commit e5676fc36e
26 changed files with 69 additions and 1588 deletions

View File

@@ -210,9 +210,6 @@ class AscendMetadata:
# (num_tokens,)
slot_mapping: torch.Tensor = None
# *************************** Other Properties *************************** #
enable_dbo_across_dp: bool = False
prefill: Optional[AscendMetadataForPrefill] = None
decode_meta: Optional[AscendMetadataForDecode] = None
@@ -371,7 +368,6 @@ class AscendAttentionMetadataBuilder:
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
num_prefills=num_prefills,
num_decodes=num_decodes,
prefill=prefill_metadata,

View File

@@ -36,9 +36,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import get_graph_params
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
@@ -184,7 +181,6 @@ class AscendMLAMetadata:
decode: Optional[AscendMLADecodeMetadata] = None
prefill: Optional[AscendMLAPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
@@ -195,17 +191,6 @@ class AscendMLAMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendMLAMetadata"]:
"""Split metadata for multi-stream with AscendMLAMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLAMetadata,
)
M = TypeVar("M", bound=AscendMLAMetadata)
@@ -538,7 +523,6 @@ class AscendMLAMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)
def build_for_graph_capture(
@@ -1158,14 +1142,8 @@ class AscendMLAImpl(MLAAttentionImpl):
else:
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope, k_nope, k_nope, **common_kwargs)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj(attn_output)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
return self._v_up_proj(attn_output)
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
bsz = attn_metadata.num_decode_tokens
@@ -1423,13 +1401,8 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
kv_cache[0].shape[1], attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[:num_decode_tokens] = output_decode
o_proj_input[:num_decode_tokens] = output_decode
if prefill_preprocess_res is not None:
# FIX: aicore move should be also placed on the comm stream in dbo,
@@ -1445,36 +1418,19 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[
num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
if current_ms_metadata is None:
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
current_ms_metadata.after_comm_event.record()
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
del o_proj_input
has_prefill = attn_metadata.num_prefills > 0
@@ -1719,18 +1675,9 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
seq_mask_pcp[:, i])
attn_output = attn_out_g
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj(attn_output)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
# TODO use update op to replace this
return self._v_up_proj(attn_output)
# TODO use update op to replace this
def _update_out_and_lse(
self,
out: torch.Tensor,

View File

@@ -17,11 +17,8 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@@ -138,7 +135,6 @@ class AscendSFAMetadata:
decode: Optional[AscendSFADecodeMetadata] = None
prefill: Optional[AscendSFAPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
@@ -149,17 +145,6 @@ class AscendSFAMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendSFAMetadata"]:
"""Split metadata for multi-stream with AscendSFAMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLAMetadata,
)
M = TypeVar("M", bound=AscendSFAMetadata)
@@ -434,7 +419,6 @@ class AscendSFAMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)

View File

@@ -91,8 +91,6 @@ class AscendCommonAttentionMetadata:
attn_state: Any = None
enable_dbo_across_dp: bool = False
is_only_prefill: bool = False
graph_pad_size: int = -1

View File

@@ -82,9 +82,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
),
# Whether to enable DBO feature for deepseek model.
"VLLM_ASCEND_ENABLE_DBO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
# Whether to enable the model execute time observe profile. Disable it when
# running vllm ascend in production environment.
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":

View File

@@ -1,29 +0,0 @@
from dataclasses import dataclass
from enum import Enum
class MSEventKey(Enum):
ATTN_COM_FINISH = 0
ATTN_AR_FINISH = 1
FFN_COM_FINISH = 2
FFN_AR_FINISH = 3
# events for MOE dispatch and combine
MOE_BEFORE_COMM = 4
MOE_AFTER_COMM = 5
# events for shared expert
MOE_SE_COMM_FINISH = 6
MOE_SE_COMP_FINISH = 7
MOE_GATE_FINISH = 8
@dataclass
class MSAttentionMetadataSplitConfig:
"""
micro batch split config for split attention metadata
"""
# micro batch num
num_micro_batches: int = 2
# split micro batches only when total tokens >= min_total_tokens_to_split
min_total_tokens_to_split: int = 256
# split micro batches only when prefill tokens >= min_prefill_tokens_to_split
min_prefill_tokens_to_split: int = 64

View File

@@ -1,67 +0,0 @@
from contextlib import contextmanager
from typing import Any
_ms_comm_context: Any = None
_cur_micro_batch_num: int = -1
_ms_layer_index_context: int = -1
_ms_metadata_context: Any = None
_ms_attn_metadata_context: Any = None
def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
attn_metadata: Any):
"""
set multistream layer context before transformer layers
"""
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
_ms_layer_index_context = start_layer
_ms_metadata_context = ms_metadata
_ms_attn_metadata_context = attn_metadata
def reset_multistream_layer_context():
"""
reset multistream layer context
"""
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
_ms_layer_index_context = -1
_ms_metadata_context = None
_ms_attn_metadata_context = None
def get_multistream_layer_context():
"""
get multistream layer context
"""
return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
def advance_step_multistream_layer_context():
"""
advance multistream layer index context
"""
global _ms_layer_index_context
_ms_layer_index_context += 1
def get_multistream_comm_context() -> Any:
"""Get the current comm forward context."""
return _ms_comm_context
def get_multistream_microbatch_context() -> int:
return _cur_micro_batch_num
@contextmanager
def set_multistream_context(context: Any, micro_batch_num: int):
"""A context manager that stores the current comm forward context,
can be attention metadata, etc."""
global _ms_comm_context, _cur_micro_batch_num
_ms_comm_context = context
_cur_micro_batch_num = micro_batch_num
try:
yield
finally:
_ms_comm_context = None
_cur_micro_batch_num = -1

View File

@@ -1,22 +0,0 @@
from .context import (get_multistream_layer_context,
get_multistream_microbatch_context)
# vllm v1 use get_forward_context to get the attn_metadata,
# we can use this decorator to update the attn metadata
def set_multistream_support():
def decorator(func):
def wrapper():
context = func()
layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
)
micro_batch_num = get_multistream_microbatch_context()
if layer_index != -1 and micro_batch_num != -1:
context.attn_metadata = attn_metadata[micro_batch_num]
return context
return wrapper
return decorator

View File

@@ -1,61 +0,0 @@
from typing import List, Optional, Tuple, Union
import torch
from vllm.forward_context import get_forward_context
from .base import MSEventKey
from .context import (get_multistream_layer_context,
reset_multistream_layer_context,
set_multistream_layer_context)
from .metadata import MultiStreamMetadata
class MultiStreamPreTransformerLayer(torch.nn.Module):
def __init__(self, multistream_metadata: MultiStreamMetadata):
super().__init__()
self.multistream_metadata = multistream_metadata
def forward(
self,
intput_tensors: List[torch.Tensor],
):
attn_metadata = get_forward_context().attn_metadata
if self.multistream_metadata is None or attn_metadata is None:
set_multistream_layer_context(-1, None, None)
return attn_metadata, intput_tensors
# TODO add attn_metadata management
do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
attn_metadata, intput_tensors)
if do_ms:
set_multistream_layer_context(
self.multistream_metadata.start_layer,
self.multistream_metadata, attn_metadata)
else:
set_multistream_layer_context(-1, None, None)
return attn_metadata, intput_tensors
class MultiStreamPostTransformerLayer(torch.nn.Module):
def __init__(self, multistream_metadata: MultiStreamMetadata):
super().__init__()
self.multistream_metadata = multistream_metadata
def forward(self,
input_tensors: Union[List[Tuple[torch.Tensor]],
List[torch.Tensor],
List[List[torch.Tensor]]],
wait_layer_index: Optional[int] = None):
if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
return input_tensors
layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
)
if layer_index >= 0:
true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
self.multistream_metadata.try_wait_event(
true_wait_layer,
self.multistream_metadata.ms_config.num_micro_batches - 1,
MSEventKey.FFN_AR_FINISH)
reset_multistream_layer_context()
return self.multistream_metadata.merge_micro_batches(input_tensors)

View File

@@ -1,182 +0,0 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
from vllm.sequence import IntermediateTensors
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from .base import MSAttentionMetadataSplitConfig, MSEventKey
def split_micro_batches_tensors(input_tensors,
split_index: int,
keys: Optional[List[str]] = None):
if isinstance(input_tensors, list):
micro_batches = []
for tensor in input_tensors:
if tensor is None:
micro_batches.append([None, None])
else:
micro_batches.append(
[tensor[:split_index], tensor[split_index:]])
return micro_batches
elif isinstance(input_tensors, torch.Tensor):
return [input_tensors[:split_index], input_tensors[split_index:]]
elif input_tensors is None:
return [None, None]
elif isinstance(input_tensors, Dict):
assert keys is not None
micro_batches_pre = {}
for key in keys:
micro_batches_pre[key] = input_tensors[key][:split_index]
micro_batches_post = {}
for key in keys:
micro_batches_post[key] = input_tensors[key][split_index:]
return [micro_batches_pre, micro_batches_post]
else:
raise NotImplementedError
@dataclass
class MultiStreamStepMetadata:
comm_stream: torch.npu.Stream = None
before_comm_event: torch.npu.Event = None
after_comm_event: torch.npu.Event = None
@dataclass
class MultiStreamConfig:
"""Controls the behavior of multi-stream models."""
min_total_tokens_to_split: int = 256
min_prefill_tokens_to_split: int = 64
num_micro_batches: int = 2
imbalance_ratio: float = 0.1
class MultiStreamMetadata:
# direct stream
calculate_stream = None
# delay stream
communicate_stream = None
# events
ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
# multi-stream-flag
enable_multi_stream: bool = False
def __init__(
self,
calculate_stream: torch.npu.Stream,
communicate_stream: torch.npu.Stream,
start_layer: int,
end_layer: int,
event_keys: List[MSEventKey],
multistream_config: Optional[MultiStreamConfig],
causal_lm: bool = True,
):
self.calculate_stream = calculate_stream
self.communicate_stream = communicate_stream
self.start_layer = start_layer
self.end_layer = end_layer
self.ms_config = multistream_config
self.causal_lm = causal_lm
self._build_events(event_keys)
self._build_ms_split_config()
def _build_events(self, event_keys):
if self.ms_config is not None:
for i in range(self.start_layer - 1, self.end_layer):
self.ms_events[i] = {}
for j in range(self.ms_config.num_micro_batches):
self.ms_events[i][j] = {}
for key in event_keys:
self.ms_events[i][j][key] = torch.npu.Event()
def _build_ms_split_config(self):
if self.ms_config is not None:
self.ms_split_config = MSAttentionMetadataSplitConfig(
num_micro_batches=self.ms_config.num_micro_batches,
min_total_tokens_to_split=self.ms_config.
min_total_tokens_to_split,
min_prefill_tokens_to_split=self.ms_config.
min_prefill_tokens_to_split,
)
def try_wait_event(self, layer_index: int, micro_batch_index: int,
event_key: MSEventKey):
self.ms_events[layer_index][micro_batch_index][event_key].wait()
def try_record_event(self, layer_index: int, micro_batch_index: int,
event_key: MSEventKey):
self.ms_events[layer_index][micro_batch_index][event_key].record()
def split_micro_batch(
self,
attn_metadata: "AscendMLAMetadata",
intput_tensors: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
intermediate_tensors_keys: Optional[List[str]] = None,
) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
List[torch.Tensor], List[List[torch.Tensor]]], Union[
IntermediateTensors, List[IntermediateTensors]]]:
attn_metadata_list = attn_metadata.split_metadata_for_multistream(
self.ms_split_config)
if len(attn_metadata_list) == 1:
return False, attn_metadata_list[
0], intput_tensors, intermediate_tensors
split_index = attn_metadata_list[0].slot_mapping.shape[0]
input_tensors = split_micro_batches_tensors(intput_tensors,
split_index)
if intermediate_tensors is not None:
inter_tensors_list = split_micro_batches_tensors(
intermediate_tensors.tensors, split_index,
intermediate_tensors_keys)
intermediate_tensors = [
IntermediateTensors(inter_tensors)
for inter_tensors in inter_tensors_list
]
return True, attn_metadata_list, input_tensors, intermediate_tensors
def merge_micro_batches(
self, input_tensors: Union[List[torch.Tensor],
List[List[torch.Tensor]]]
) -> List[torch.Tensor]:
if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
return input_tensors
batch: List[Optional[torch.Tensor]] = []
for tensors in input_tensors:
if tensors is None or tensors[0] is None:
batch.append(None)
else:
batch.append(torch.cat(tensors, dim=0))
return batch
def make_multistream_metadata_ds(
start_layer: int,
end_layer: int,
causal_lm: bool = True,
multistream_config: Optional[MultiStreamConfig] = None,
):
if multistream_config is None:
return None
event_keylist = [
MSEventKey.ATTN_COM_FINISH,
MSEventKey.ATTN_AR_FINISH,
MSEventKey.FFN_COM_FINISH,
MSEventKey.FFN_AR_FINISH,
MSEventKey.MOE_BEFORE_COMM,
MSEventKey.MOE_AFTER_COMM,
MSEventKey.MOE_SE_COMM_FINISH,
MSEventKey.MOE_SE_COMP_FINISH,
MSEventKey.MOE_GATE_FINISH,
]
return MultiStreamMetadata(
calculate_stream=torch.npu.current_stream(),
communicate_stream=torch.npu.Stream(),
start_layer=start_layer,
end_layer=end_layer,
multistream_config=multistream_config,
event_keys=event_keylist,
causal_lm=causal_lm,
)

View File

@@ -1,247 +0,0 @@
from copy import deepcopy
from typing import Any, List, Optional
import numpy as np
import torch
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from .base import MSAttentionMetadataSplitConfig
def compute_split_seq_index(
query_lens: Optional[list[int]],
attn_state: AscendAttentionState,
num_tokens: int,
imbalance_ratio: float = 0.1,
) -> list[int]:
if attn_state != AscendAttentionState.DecodeOnly:
assert query_lens is not None
total_tokens = sum(query_lens)
# the first index in last split
tokens, split_index = 0, 0
for value in query_lens:
tokens += value
split_index += 1
if tokens >= total_tokens // 2:
# check the current split index
if abs(tokens -
total_tokens // 2) < total_tokens * imbalance_ratio:
return [tokens, split_index]
# check the previous split index
elif abs(tokens - total_tokens // 2 -
value) < total_tokens * imbalance_ratio:
return [tokens - value, split_index - 1]
# fail to split if it is imbalanced
# TODO: split tokens in seq
else:
return [0, 0]
else:
tokens = num_tokens // 2
return [tokens, tokens]
return [0, 0]
def split_attn_tensor_type(
input_tensor: torch.Tensor,
index: int,
) -> List[torch.Tensor]:
return [input_tensor[:index], input_tensor[index:]]
def split_attn_int_type(
var: int,
index: int,
) -> List[torch.Tensor]:
return [min(var, index), max(var - index, 0)]
def model_input_split_v1_mla_attn(
attn_metadata,
_metadata_cls,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> List[Any]:
assert 0 < ms_split_config.num_micro_batches < 3
if attn_metadata is None:
return [attn_metadata]
[token_index,
seq_index] = compute_split_seq_index(attn_metadata.query_lens,
attn_metadata.attn_state,
attn_metadata.num_decode_tokens)
if token_index == 0 or seq_index == 0 or seq_index == len(
attn_metadata.query_lens):
return [attn_metadata]
query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
dtype=int)
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
if attn_metadata.num_prefills > 0:
prefill_query_start_loc = np.zeros(
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
np.cumsum(attn_metadata.prefill.query_lens,
out=prefill_query_start_loc[1:])
# split attn metadata
[slot_mapping_pre,
slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
token_index)
[num_decodes_pre,
num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
seq_index)
[num_decode_tokens_pre, num_decode_tokens_post
] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
[num_prefills_pre, num_prefills_post
] = split_attn_int_type(attn_metadata.num_prefills,
max(0, seq_index - attn_metadata.num_decodes))
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
query_start_loc_pre = query_start_loc_post = None
if attn_metadata.query_start_loc is not None:
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
query_start_loc_post = deepcopy(
attn_metadata.query_start_loc[seq_index:]
) - attn_metadata.query_start_loc[seq_index]
[block_table_pre,
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
seq_index)
assert attn_metadata.attn_mask is not None
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
# the attn_mla kernel in torch npu only accept 128*128 attn mask
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
attn_state_pre = attn_state_post = attn_metadata.attn_state
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
# should be none in decode only state
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
else:
# chunked prefill
if num_prefills_pre > 0:
attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
seq_lens_pre)].contiguous()
attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_post = attn_metadata.attn_mask[
token_index:, :max(seq_lens_post)].contiguous()
else:
attn_state_pre = AscendAttentionState.DecodeOnly
attn_mask_pre = None
attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_post = attn_metadata.attn_mask[
token_index:, :max(seq_lens_post)].contiguous()
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAPrefillMetadata)
if num_prefills_pre > 0:
# split metadata.prefill
[input_positions_pre, input_positions_post] = split_attn_tensor_type(
attn_metadata.prefill.input_positions,
token_index - attn_metadata.num_decode_tokens)
[block_tables_pre, block_tables_post
] = split_attn_tensor_type(attn_metadata.prefill.block_table,
seq_index - attn_metadata.num_decodes)
[prefill_query_lens_pre, prefill_query_lens_post
] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
seq_index - attn_metadata.num_decodes)
prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[:
seq_index
+
1 -
attn_metadata
.
num_decodes]
prefill_query_start_loc_post = deepcopy(
attn_metadata.prefill.query_start_loc[seq_index -
attn_metadata.num_decodes:]
) - attn_metadata.prefill.query_start_loc[seq_index -
attn_metadata.num_decodes]
context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
context_len_post = seq_lens_post
prefill_max_query_len_pre = max(prefill_query_lens_pre)
prefill_max_query_len_post = max(prefill_query_lens_post)
prefill_pre = AscendMLAPrefillMetadata(
attn_mask=attn_mask_pre,
query_lens=prefill_query_lens_pre,
seq_lens=seq_lens_pre,
query_start_loc=prefill_query_start_loc_pre,
input_positions=input_positions_pre,
context_lens=context_len_pre,
block_table=block_tables_pre,
max_query_len=prefill_max_query_len_pre,
max_seq_lens=context_len_pre.max().item(),
)
prefill_post = AscendMLAPrefillMetadata(
attn_mask=attn_mask_post,
query_lens=prefill_query_lens_post,
seq_lens=seq_lens_post,
query_start_loc=prefill_query_start_loc_post,
input_positions=input_positions_post,
context_lens=context_len_post,
block_table=block_tables_post,
max_query_len=prefill_max_query_len_post,
max_seq_lens=context_len_post.max().item(),
)
decode_pre = attn_metadata.decode
decode_post = None
else:
# prefill is None, split metadata.decode
[input_positions_pre, input_positions_post
] = split_attn_tensor_type(attn_metadata.decode.input_positions,
token_index)
[block_tables_pre, block_tables_post
] = split_attn_tensor_type(attn_metadata.decode.block_table,
seq_index)
[decode_seq_lens_pre,
decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
decode_pre = AscendMLADecodeMetadata(
input_positions=input_positions_pre,
block_table=block_tables_pre,
seq_lens=decode_seq_lens_pre,
max_seq_lens=max(decode_seq_lens_pre),
seq_lens_list=decode_seq_lens_pre.tolist(),
)
decode_post = AscendMLADecodeMetadata(
input_positions=input_positions_post,
block_table=block_tables_post,
seq_lens=decode_seq_lens_post,
max_seq_lens=max(decode_seq_lens_post),
seq_lens_list=decode_seq_lens_post.tolist(),
)
prefill_pre = None
prefill_post = attn_metadata.prefill
# construct metadata
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
attention_metadata_pre = _metadata_cls(
num_actual_tokens=token_index,
num_input_tokens=token_index,
head_dim=attn_metadata.head_dim,
slot_mapping=slot_mapping_pre,
seq_lens=seq_lens_pre,
query_start_loc=query_start_loc_pre,
block_tables=block_table_pre,
num_decodes=num_decodes_pre,
num_prefills=num_prefills_pre,
num_decode_tokens=num_decode_tokens_pre,
attn_state=attn_state_pre,
attn_mask=attn_mask_pre,
prefill=prefill_pre,
decode=decode_pre,
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
)
attention_metadata_post = _metadata_cls(
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
num_input_tokens=attn_metadata.num_input_tokens - token_index,
head_dim=attn_metadata.head_dim,
slot_mapping=slot_mapping_post,
seq_lens=seq_lens_post,
query_start_loc=query_start_loc_post,
block_tables=block_table_post,
num_decodes=num_decodes_post,
num_prefills=num_prefills_post,
num_decode_tokens=num_decode_tokens_post,
attn_mask=attn_mask_post,
attn_state=attn_state_post,
prefill=prefill_post,
decode=decode_post,
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
)
return [attention_metadata_pre, attention_metadata_post]

View File

@@ -122,10 +122,11 @@ class MtpProposer(Proposer):
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None) -> None:
if not self.torchair_graph_enabled:
# TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(num_tokens,
with_prefill, False)
(
num_tokens,
num_tokens_across_dp,
with_prefill,
) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
@@ -429,10 +430,9 @@ class MtpProposer(Proposer):
if not self.torchair_graph_enabled:
# torch mode need to update num_tokens_across_dp
# TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(
num_input_tokens, self.runner.with_prefill, False)
(num_input_tokens, num_tokens_across_dp,
with_prefill) = self.runner._sync_metadata_across_dp(
num_input_tokens, self.runner.with_prefill)
else:
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp

View File

@@ -264,8 +264,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
max_query_len=common_attn_metadata.max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
attn_state=attn_state)
return attn_metadata

View File

@@ -20,9 +20,6 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
npu_stream_switch, npu_wait_tensor)
@@ -141,7 +138,6 @@ class AscendMLATorchairMetadata:
decode: Optional[AscendMLATorchairDecodeMetadata] = None
prefill: Optional[AscendMLATorchairPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
@@ -152,17 +148,6 @@ class AscendMLATorchairMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendMLATorchairMetadata"]:
"""Split metadata for multi-stream with AscendMLATorchairMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLATorchairMetadata,
)
M = TypeVar("M", bound=AscendMLATorchairMetadata)
@@ -576,7 +561,6 @@ class AscendMLATorchairMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)
def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
@@ -1072,15 +1056,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output,
enable_multistream_mla)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj_and_o_proj(attn_output)
return self._v_up_proj_and_o_proj(attn_output, enable_multistream_mla)
def forward(
self,
@@ -1248,14 +1225,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
prefill_k_c_normed,
prefill_k_pe, kv_cache,
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
o_proj_input[num_decode_tokens:] = output_prefill
else:
o_proj_input[num_decode_tokens:] = output_prefill
o_proj_input[num_decode_tokens:] = output_prefill
if has_decode:
if self.running_in_graph:
@@ -1269,35 +1239,19 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
decode_k_nope,
decode_k_pe, kv_cache,
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[:num_decode_tokens] = output_decode
else:
o_proj_input[:num_decode_tokens] = output_decode
o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
if current_ms_metadata is None:
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record()
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
del o_proj_input
return output_padded

View File

@@ -110,30 +110,28 @@ class NPUTorchairModelRunner(NPUModelRunner):
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
def _sync_metadata_across_dp(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
self, num_tokens: int,
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
"""Override from NPUModelRunner to pad num_tokens"""
if self.enable_shared_expert_dp:
# Padding is not required for shared_expert_dp cases in eager mode.
return num_tokens, None, with_prefill, enable_dbo
return num_tokens, None, with_prefill
if self.dp_size == 1:
if not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
return num_tokens, None, with_prefill, enable_dbo
return maybe_padded_num_tokens, None, with_prefill
return num_tokens, None, with_prefill
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
num_tokens_across_dp = torch.zeros(self.dp_size + 1,
dtype=torch.int32,
device="npu")
num_tokens_across_dp[self.dp_rank] = num_tokens
num_tokens_across_dp[-2] = int(with_prefill)
num_tokens_across_dp[-1] = int(not enable_dbo)
num_tokens_across_dp[-1] = int(with_prefill)
dist.all_reduce(num_tokens_across_dp,
group=get_dp_group().device_group)
with_prefill = bool(num_tokens_across_dp[-2])
enable_dbo = not bool(num_tokens_across_dp[-1])
num_tokens_across_dp = num_tokens_across_dp[:-2]
with_prefill = bool(num_tokens_across_dp[-1])
num_tokens_across_dp = num_tokens_across_dp[:-1]
if not with_prefill:
max_num_token = num_tokens_across_dp.max().item()
@@ -146,7 +144,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
else:
maybe_padded_num_tokens = num_tokens
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill
def _build_dummy_attn_metadata(
self,

View File

@@ -21,8 +21,6 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import is_enable_nz
from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -141,7 +139,6 @@ class AscendSFATorchairMetadata:
decode: Optional[AscendSFATorchairDecodeMetadata] = None
prefill: Optional[AscendSFATorchairPrefillMetadata] = None
enable_dbo_across_dp: bool = False
is_prefill: bool = False
is_decode: bool = False
@@ -154,17 +151,6 @@ class AscendSFATorchairMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendSFATorchairMetadata"]:
"""Split metadata for multi-stream with AscendSFATorchairMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendSFATorchairMetadata,
)
M = TypeVar("M", bound=AscendSFATorchairMetadata)
@@ -616,7 +602,6 @@ class AscendSFATorchairMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
is_prefill=is_prefill,
is_decode=is_decode)

View File

@@ -758,13 +758,13 @@ def get_default_buffer_config() -> dict:
def calculate_dp_buffer_size() -> int:
"""
formula of dp buffer size:
dp_size + 2 (flags: with_prefill and enable_dbo)
dp_size + 1 (flags: with_prefill)
"""
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
dp_size = vllm_config.parallel_config.data_parallel_size
int32_size = torch.iinfo(torch.int32).bits // 8
dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
dp_buffer_size = math.ceil((dp_size + 1) * int32_size / (1024 * 1024))
return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)

View File

@@ -121,7 +121,6 @@ from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
@@ -859,8 +858,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
def _sync_metadata_across_dp(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
self, num_tokens: int,
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
# TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
# our case, we still need to sync the other two flags as well. So we need to
# include them in the all_reduce operation, and more over, we CANNOT skip it
@@ -868,31 +867,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here
# immediately once the other two flags are no longer needed.
if self.dp_size == 1:
return num_tokens, None, with_prefill, enable_dbo
return num_tokens, None, with_prefill
# Sync num_tokens, with_prefill, enable_dbo across dp ranks
# Sync num_tokens, with_prefill across dp ranks
num_tokens_tensor = torch.tensor([
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
],
dtype=torch.int32,
device="npu")
flags_tensor = torch.tensor(
[int(with_prefill), int(not enable_dbo)],
dtype=torch.int32,
device="npu")
flags_tensor = torch.tensor([int(with_prefill)],
dtype=torch.int32,
device="npu")
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
# Unpack the results
num_tokens_across_dp = packed_tensor[:-2]
synced_flags = packed_tensor[-2:]
num_tokens_across_dp = packed_tensor[:-1]
synced_flags = packed_tensor[-1:]
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
global_with_prefill = bool(synced_flags[0])
global_enable_dbo = not bool(synced_flags[1])
# Create a tensor for num_tokens_after_padding
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
@@ -900,28 +897,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
device="cpu",
dtype=torch.int32)
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
attn_state: AscendAttentionState,
num_tokens: int) -> bool:
# do the checks for dp + dbo
if attn_state in [
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
]:
return False
# considering the case that one dp rank may enable dbo while others may not
if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
return False
# TODO: remove it if token-level microbatch is enabled
[token_index,
seq_index] = compute_split_seq_index(query_lens, attn_state,
num_tokens)
if token_index == 0 or seq_index == 0 or seq_index == len(
query_lens) or num_tokens < 256:
return False
return True
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill
def get_model(self) -> nn.Module:
# get raw model out of the aclgraph wrapper.
@@ -1430,16 +1406,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
attn_state,
total_num_scheduled_tokens)
# Get info across DP ranks.
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
# Otherwise, it's just max_tokens_across_dp_cpu
(maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
with_prefill, enable_dbo)
(maybe_padded_num_tokens, num_tokens_across_dp,
with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
with_prefill)
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
# We should consider removing maybe_padded_num_tokens later
@@ -1707,7 +1680,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
enable_dbo_across_dp=enable_dbo,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
max_query_len=max_num_scheduled_tokens,
graph_pad_size=self.graph_pad_size,
@@ -2603,8 +2575,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens = math.ceil(num_tokens / tp_size) * tp_size
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
(num_tokens, num_tokens_across_dp,
with_prefill) = self._sync_metadata_across_dp(num_tokens,
with_prefill)
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)