[perf]: support dual-batch overlap(dbo) for deepseek (#941)

### What this PR does / why we need it?
Based on the design of dual-batch overlap proposed by Deepseek team and
also the implementation of fused moe in VLLM project, we implement the
multi-stream(also known as dual-batch) overlap for deepseek+mla on
Ascend NPU. We split the input batch of model into two microbatches and
then overlap the comp/comm ops in attention and moe layers using two
streams to improve the performance. Our approach can be easily extended
when adding dispatch/combine communications for moe layer.
Compared with the previously proposed
[draft](https://github.com/vllm-project/vllm-ascend/pull/842), we use
one stream for computation ops and the other for communication ops,
separately. In out opinions, it is beneficial for arranging the order of
executing different ops and thus avoiding the contention of
computation/communication resources.

ref: [overlap for
llama](https://github.com/vllm-project/vllm/pull/15787/files)
ref: [dbo in
sglang](https://github.com/sgl-project/sglang/pull/4068/files#diff-b4937569fc71f6ad215181b633b2f89c7183a2b4ac39e41fc22635599a9be7de)

### Does this PR introduce _any_ user-facing change?
Adding an env variable "VLLM_ENABLE_DBO". Users can enable dbo by
setting "VLLM_ASCEND_ENABLE_DBO=1"
See /examples/offline_dualbatch_overlap_npu.py for more info.

### How was this patch tested?

This patch can be tested with vllm-0.9.0 using its online service with
benchmark tests. We have decoupled the func of dbo from vllm and it
should be able to run without any modification to the code of vllm(some
modifications is better to implement in vllm though).



Any advice/discussion is welcome.

### Performance Benchmark

We have ran the benchmark_serving script of vllm to test the performance
after using dual-batch overlap.

`python -m vllm.entrypoints.openai.api_server \
 --model=DeepSeek-R1-W8A8 \
 --trust-remote-code \
 --distributed-executor-backend=mp \
 -tp=16 \
 --port 8006 \
 --max-num-seqs 390 \
 --max-model-len 32768 \
 --max-num-batched-tokens 65536 \
 --block-size 128 \
 --compilation_config 0 \
 --gpu-memory-utilization 0.90 \
 --disable-log-requests \
--additional-config
'{"expert_tensor_parallel_size":1,"enable_inter_dp_scheduling":true,"init_torchair_graph_batch_sizes":true,"trace_recompiles":true,"ascend_scheduler_config":{},"enable_graph_mode":false}'`

and run benchmark with the parameters of :
`--dataset-name random --random-input-len 4096 --random-output-len 1
--num-prompts 200 --max-concurrency 8 --request-rate 5
--metric-percentiles 90`

1. test with the version using allgather+allreduce in Ascend 910B (tp16
ep16 + deepseek r1 w8a8)

2. test with the version using alltoall: 

prefill qps: 0.90 -> 1.01
Mean TTFT:8226->7432ms

The overlap approach when using alltoall communication can be further
optimized by overlapping micro-batch1's moe comp with micro-batch2's
dispatch a2a comm

---------

Signed-off-by: zhuohuan <zxdu1997@gmail.com>
This commit is contained in:
zxdukki
2025-06-07 16:46:58 +08:00
committed by GitHub
parent 3640c60b0e
commit 87ebaef4e4
14 changed files with 1896 additions and 11 deletions

View File

@@ -0,0 +1,51 @@
import os
import time
from vllm import LLM, SamplingParams
# enable dual-batch overlap for vllm ascend
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
os.environ["VLLM_USE_V1"] = "1"
# Sample prompts.
prompts = ["The president of the United States is"] * 41
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
def main():
# Create an LLM.
llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic",
enforce_eager=True,
tensor_parallel_size=2,
max_model_len=4096,
trust_remote_code=True,
additional_config={
"torchair_graph_config": {
"enabled": False
},
"ascend_scheduler_config": {
"enabled": True
},
"expert_tensor_parallel_size": 1
})
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Add a buffer to wait for profiler in the background process
# (in case MP is on) to finish writing profiling output.
time.sleep(10)
if __name__ == "__main__":
main()

View File

@@ -81,3 +81,17 @@ def test_models_distributed_topk() -> None:
distributed_executor_backend="mp", distributed_executor_backend="mp",
) as vllm_model: ) as vllm_model:
vllm_model.generate(example_prompts, sampling_params) vllm_model.generate(example_prompts, sampling_params)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
def test_models_distributed_DeepSeek_dbo():
example_prompts = ["The president of the United States is"] * 41
dtype = "half"
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
with VllmRunner(
"deepseek-ai/DeepSeek-V2-Lite",
dtype=dtype,
tensor_parallel_size=4,
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)

View File

@@ -13,6 +13,9 @@ from vllm.model_executor.layers.linear import (LinearBase,
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -117,6 +120,7 @@ class AscendMLAMetadata:
with_prefill_across_dp: bool = False with_prefill_across_dp: bool = False
query_lens: Optional[list[int]] = None
# The dimension of the attention heads # The dimension of the attention heads
head_dim: Optional[int] = None head_dim: Optional[int] = None
attn_mask: torch.Tensor = None attn_mask: torch.Tensor = None
@@ -135,6 +139,17 @@ class AscendMLAMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,", # f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.") # f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendMLAMetadata"]:
"""Split metadata for multi-stream with AscendMLAMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLAMetadata,
)
M = TypeVar("M", bound=AscendMLAMetadata) M = TypeVar("M", bound=AscendMLAMetadata)
@@ -386,6 +401,7 @@ class AscendMLAMetadataBuilder:
return self.metadata_cls( # type: ignore return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
query_lens=query_lens.tolist(),
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(), head_dim=self.runner.model_config.get_head_size(),
num_decodes=self._num_decodes, num_decodes=self._num_decodes,
@@ -585,7 +601,15 @@ class AscendMLAImpl(MLAAttentionImpl):
) )
attn_output = attn_output.reshape( attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim]) [num_tokens, self.num_heads * self.v_head_dim])
return self.o_proj(attn_output)[0]
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output)[0]
def exec_kv( def exec_kv(
self, self,
@@ -685,7 +709,14 @@ class AscendMLAImpl(MLAAttentionImpl):
context_lens=attn_metadata.decode.seq_lens, # type:ignore context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank, mla_vheadsize=self.kv_lora_rank,
out=attn_output) out=attn_output)
return self._v_up_proj_and_o_proj(attn_output) current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj_and_o_proj(attn_output)
def forward( def forward(
self, self,
@@ -811,16 +842,38 @@ class AscendMLAImpl(MLAAttentionImpl):
key_cache=kv_cache, key_cache=kv_cache,
slot_indices=attn_metadata.slot_mapping.flatten()) slot_indices=attn_metadata.slot_mapping.flatten())
if has_prefill: if has_prefill:
output[num_decode_tokens:] = self._forward_prefill( # FIX: aicore move should be also placed on the comm stream in dbo,
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, # otherwise it may affect the accuracy
attn_metadata) # TODO: use an elegant way to overlap
output_prefill = self._forward_prefill(prefill_q,
prefill_k_c_normed,
prefill_k_pe, kv_cache,
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
output[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
output[num_decode_tokens:] = output_prefill
if has_decode: if has_decode:
if self.running_in_graph: if self.running_in_graph:
return self._forward_decode(decode_ql_nope, decode_q_pe, return self._forward_decode(decode_ql_nope, decode_q_pe,
decode_k_nope, decode_k_pe, decode_k_nope, decode_k_pe,
kv_cache, attn_metadata) kv_cache, attn_metadata)
else: else:
output[:num_decode_tokens] = self._forward_decode( output_decode = self._forward_decode(decode_ql_nope,
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, decode_q_pe,
kv_cache, attn_metadata) decode_k_nope,
decode_k_pe, kv_cache,
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
output[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
else:
output[:num_decode_tokens] = output_decode
return output_padded return output_padded

View File

@@ -107,6 +107,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
# Whether to enable the trace recompiles from pytorch. # Whether to enable the trace recompiles from pytorch.
"VLLM_ASCEND_TRACE_RECOMPILES": "VLLM_ASCEND_TRACE_RECOMPILES":
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))), lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
"VLLM_ASCEND_ENABLE_DBO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
# Whether to enable the model execute time observe profile. Disable it when # Whether to enable the model execute time observe profile. Disable it when
# running vllm ascend in production environment. # running vllm ascend in production environment.
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":

View File

@@ -1,7 +1,10 @@
from vllm import ModelRegistry from vllm import ModelRegistry
import vllm_ascend.envs as envs
def register_model(): def register_model():
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
@@ -22,9 +25,14 @@ def register_model():
"vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration" "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
) )
ModelRegistry.register_model( if envs.VLLM_ASCEND_ENABLE_DBO:
"DeepseekV2ForCausalLM", ModelRegistry.register_model(
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") "DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
else:
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
ModelRegistry.register_model( ModelRegistry.register_model(
"DeepseekV3ForCausalLM", "DeepseekV3ForCausalLM",

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from enum import Enum
class MSEventKey(Enum):
ATTN_COM_FINISH = 0
ATTN_AR_FINISH = 1
FFN_COM_FINISH = 2
FFN_AR_FINISH = 3
# events for MOE dispatch and combine
MOE_BEFORE_COMM = 4
MOE_AFTER_COMM = 5
# events for shared expert
MOE_SE_COMM_FINISH = 6
MOE_SE_COMP_FINISH = 7
MOE_GATE_FINISH = 8
@dataclass
class MSAttentionMetadataSplitConfig:
"""
micro batch split config for split attention metadata
"""
# micro batch num
num_micro_batches: int = 2
# split micro batches only when total tokens >= min_total_tokens_to_split
min_total_tokens_to_split: int = 256
# split micro batches only when prefill tokens >= min_prefill_tokens_to_split
min_prefill_tokens_to_split: int = 64

View File

@@ -0,0 +1,67 @@
from contextlib import contextmanager
from typing import Any
_ms_comm_context: Any = None
_cur_micro_batch_num: int = -1
_ms_layer_index_context: int = -1
_ms_metadata_context: Any = None
_ms_attn_metadata_context: Any = None
def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
attn_metadata: Any):
"""
set multistream layer context before transformer layers
"""
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
_ms_layer_index_context = start_layer
_ms_metadata_context = ms_metadata
_ms_attn_metadata_context = attn_metadata
def reset_multistream_layer_context():
"""
reset multistream layer context
"""
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
_ms_layer_index_context = -1
_ms_metadata_context = None
_ms_attn_metadata_context = None
def get_multistream_layer_context():
"""
get multistream layer context
"""
return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
def advance_step_multistream_layer_context():
"""
advance multistream layer index context
"""
global _ms_layer_index_context
_ms_layer_index_context += 1
def get_multistream_comm_context() -> Any:
"""Get the current comm forward context."""
return _ms_comm_context
def get_multistream_microbatch_context() -> int:
return _cur_micro_batch_num
@contextmanager
def set_multistream_context(context: Any, micro_batch_num: int):
"""A context manager that stores the current comm forward context,
can be attention metadata, etc."""
global _ms_comm_context, _cur_micro_batch_num
_ms_comm_context = context
_cur_micro_batch_num = micro_batch_num
try:
yield
finally:
_ms_comm_context = None
_cur_micro_batch_num = -1

View File

@@ -0,0 +1,26 @@
from vllm.logger import init_logger
from .context import (get_multistream_layer_context,
get_multistream_microbatch_context)
logger = init_logger(__name__)
# vllm v1 use get_forward_context to get the attn_metadata,
# we can use this decorator to update the attn metadata
def set_multistream_support():
def decorator(func):
def wrapper():
context = func()
layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
)
micro_batch_num = get_multistream_microbatch_context()
if layer_index != -1 and micro_batch_num != -1:
context.attn_metadata = attn_metadata[micro_batch_num]
return context
return wrapper
return decorator

View File

@@ -0,0 +1,61 @@
from typing import List, Optional, Tuple, Union
import torch
from vllm.forward_context import get_forward_context
from .base import MSEventKey
from .context import (get_multistream_layer_context,
reset_multistream_layer_context,
set_multistream_layer_context)
from .metadata import MultiStreamMetadata
class MultiStreamPreTransformerLayer(torch.nn.Module):
def __init__(self, multistream_metadata: MultiStreamMetadata):
super().__init__()
self.multistream_metadata = multistream_metadata
def forward(
self,
intput_tensors: List[torch.Tensor],
):
attn_metadata = get_forward_context().attn_metadata
if self.multistream_metadata is None or attn_metadata is None:
set_multistream_layer_context(-1, None, None)
return attn_metadata, intput_tensors
# TODO add attn_metadata management
do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
attn_metadata, intput_tensors)
if do_ms:
set_multistream_layer_context(
self.multistream_metadata.start_layer,
self.multistream_metadata, attn_metadata)
else:
set_multistream_layer_context(-1, None, None)
return attn_metadata, intput_tensors
class MultiStreamPostTransformerLayer(torch.nn.Module):
def __init__(self, multistream_metadata: MultiStreamMetadata):
super().__init__()
self.multistream_metadata = multistream_metadata
def forward(self,
input_tensors: Union[List[Tuple[torch.Tensor]],
List[torch.Tensor],
List[List[torch.Tensor]]],
wait_layer_index: Optional[int] = None):
if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
return input_tensors
layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
)
if layer_index >= 0:
true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
self.multistream_metadata.try_wait_event(
true_wait_layer,
self.multistream_metadata.ms_config.num_micro_batches - 1,
MSEventKey.FFN_AR_FINISH)
reset_multistream_layer_context()
return self.multistream_metadata.merge_micro_batches(input_tensors)

View File

@@ -0,0 +1,182 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
from vllm.sequence import IntermediateTensors
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from .base import MSAttentionMetadataSplitConfig, MSEventKey
def split_micro_batches_tensors(input_tensors,
split_index: int,
keys: Optional[List[str]] = None):
if isinstance(input_tensors, list):
micro_batches = []
for tensor in input_tensors:
if tensor is None:
micro_batches.append([None, None])
else:
micro_batches.append(
[tensor[:split_index], tensor[split_index:]])
return micro_batches
elif isinstance(input_tensors, torch.Tensor):
return [input_tensors[:split_index], input_tensors[split_index:]]
elif input_tensors is None:
return [None, None]
elif isinstance(input_tensors, Dict):
assert keys is not None
micro_batches_pre = {}
for key in keys:
micro_batches_pre[key] = input_tensors[key][:split_index]
micro_batches_post = {}
for key in keys:
micro_batches_post[key] = input_tensors[key][split_index:]
return [micro_batches_pre, micro_batches_post]
else:
raise NotImplementedError
@dataclass
class MultiStreamStepMetadata:
comm_stream: torch.npu.Stream = None
before_comm_event: torch.npu.Event = None
after_comm_event: torch.npu.Event = None
@dataclass
class MultiStreamConfig:
"""Controls the behavior of multi-stream models."""
min_total_tokens_to_split: int = 256
min_prefill_tokens_to_split: int = 64
num_micro_batches: int = 2
imbalance_ratio: float = 0.1
class MultiStreamMetadata:
# direct stream
calculate_stream = None
# delay stream
communicate_stream = None
# events
ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
# multi-stream-flag
enable_multi_stream: bool = False
def __init__(
self,
calculate_stream: torch.npu.Stream,
communicate_stream: torch.npu.Stream,
start_layer: int,
end_layer: int,
event_keys: List[MSEventKey],
multistream_config: Optional[MultiStreamConfig],
causal_lm: bool = True,
):
self.calculate_stream = calculate_stream
self.communicate_stream = communicate_stream
self.start_layer = start_layer
self.end_layer = end_layer
self.ms_config = multistream_config
self.causal_lm = causal_lm
self._build_events(event_keys)
self._build_ms_split_config()
def _build_events(self, event_keys):
if self.ms_config is not None:
for i in range(self.start_layer - 1, self.end_layer):
self.ms_events[i] = {}
for j in range(self.ms_config.num_micro_batches):
self.ms_events[i][j] = {}
for key in event_keys:
self.ms_events[i][j][key] = torch.npu.Event()
def _build_ms_split_config(self):
if self.ms_config is not None:
self.ms_split_config = MSAttentionMetadataSplitConfig(
num_micro_batches=self.ms_config.num_micro_batches,
min_total_tokens_to_split=self.ms_config.
min_total_tokens_to_split,
min_prefill_tokens_to_split=self.ms_config.
min_prefill_tokens_to_split,
)
def try_wait_event(self, layer_index: int, micro_batch_index: int,
event_key: MSEventKey):
self.ms_events[layer_index][micro_batch_index][event_key].wait()
def try_record_event(self, layer_index: int, micro_batch_index: int,
event_key: MSEventKey):
self.ms_events[layer_index][micro_batch_index][event_key].record()
def split_micro_batch(
self,
attn_metadata: "AscendMLAMetadata",
intput_tensors: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
intermediate_tensors_keys: Optional[List[str]] = None,
) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
List[torch.Tensor], List[List[torch.Tensor]]], Union[
IntermediateTensors, List[IntermediateTensors]]]:
attn_metadata_list = attn_metadata.split_metadata_for_multistream(
self.ms_split_config)
if len(attn_metadata_list) == 1:
return False, attn_metadata_list[
0], intput_tensors, intermediate_tensors
split_index = attn_metadata_list[0].slot_mapping.shape[0]
input_tensors = split_micro_batches_tensors(intput_tensors,
split_index)
if intermediate_tensors is not None:
inter_tensors_list = split_micro_batches_tensors(
intermediate_tensors.tensors, split_index,
intermediate_tensors_keys)
intermediate_tensors = [
IntermediateTensors(inter_tensors)
for inter_tensors in inter_tensors_list
]
return True, attn_metadata_list, input_tensors, intermediate_tensors
def merge_micro_batches(
self, input_tensors: Union[List[torch.Tensor],
List[List[torch.Tensor]]]
) -> List[torch.Tensor]:
if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
return input_tensors
batch: List[Optional[torch.Tensor]] = []
for tensors in input_tensors:
if tensors is None or tensors[0] is None:
batch.append(None)
else:
batch.append(torch.cat(tensors, dim=0))
return batch
def make_multistream_metadata_ds(
start_layer: int,
end_layer: int,
causal_lm: bool = True,
multistream_config: Optional[MultiStreamConfig] = None,
):
if multistream_config is None:
return None
event_keylist = [
MSEventKey.ATTN_COM_FINISH,
MSEventKey.ATTN_AR_FINISH,
MSEventKey.FFN_COM_FINISH,
MSEventKey.FFN_AR_FINISH,
MSEventKey.MOE_BEFORE_COMM,
MSEventKey.MOE_AFTER_COMM,
MSEventKey.MOE_SE_COMM_FINISH,
MSEventKey.MOE_SE_COMP_FINISH,
MSEventKey.MOE_GATE_FINISH,
]
return MultiStreamMetadata(
calculate_stream=torch.npu.current_stream(),
communicate_stream=torch.npu.Stream(),
start_layer=start_layer,
end_layer=end_layer,
multistream_config=multistream_config,
event_keys=event_keylist,
causal_lm=causal_lm,
)

View File

@@ -0,0 +1,245 @@
from copy import deepcopy
from typing import Any, List, Optional
import numpy as np
import torch
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from .base import MSAttentionMetadataSplitConfig
def compute_split_seq_index(
query_lens: Optional[list[int]],
attn_state: AscendAttentionState,
num_tokens: int,
imbalance_ratio: float = 0.1,
) -> list[int]:
if attn_state != AscendAttentionState.DecodeOnly:
assert query_lens is not None
total_tokens = sum(query_lens)
# the first index in last split
tokens, split_index = 0, 0
for value in query_lens:
tokens += value
split_index += 1
if tokens >= total_tokens // 2:
# check the current split index
if abs(tokens -
total_tokens // 2) < total_tokens * imbalance_ratio:
return [tokens, split_index]
# check the previous split index
elif abs(tokens - total_tokens // 2 -
value) < total_tokens * imbalance_ratio:
return [tokens - value, split_index - 1]
# fail to split if it is imbalanced
# TODO: split tokens in seq
else:
return [0, 0]
else:
tokens = num_tokens // 2
return [tokens, tokens]
return [0, 0]
def split_attn_tensor_type(
input_tensor: torch.Tensor,
index: int,
) -> List[torch.Tensor]:
return [input_tensor[:index], input_tensor[index:]]
def split_attn_int_type(
var: int,
index: int,
) -> List[torch.Tensor]:
return [min(var, index), max(var - index, 0)]
def model_input_split_v1_mla_attn(
attn_metadata,
_metadata_cls,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> List[Any]:
assert 0 < ms_split_config.num_micro_batches < 3
if attn_metadata is None:
return [attn_metadata]
[token_index,
seq_index] = compute_split_seq_index(attn_metadata.query_lens,
attn_metadata.attn_state,
attn_metadata.num_decode_tokens)
if token_index == 0 or seq_index == 0 or seq_index == len(
attn_metadata.query_lens):
return [attn_metadata]
query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
dtype=int)
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
if attn_metadata.num_prefills > 0:
prefill_query_start_loc = np.zeros(
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
np.cumsum(attn_metadata.prefill.query_lens,
out=prefill_query_start_loc[1:])
# split attn metadata
[slot_mapping_pre,
slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
token_index)
[num_decodes_pre,
num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
seq_index)
[num_decode_tokens_pre, num_decode_tokens_post
] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
[num_prefills_pre, num_prefills_post
] = split_attn_int_type(attn_metadata.num_prefills,
max(0, seq_index - attn_metadata.num_decodes))
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
query_start_loc_post = deepcopy(
attn_metadata.query_start_loc[seq_index:]
) - attn_metadata.query_start_loc[seq_index]
[block_table_pre,
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
seq_index)
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
# the attn_mla kernel in torch npu only accept 128*128 attn mask
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
attn_state_pre = attn_state_post = attn_metadata.attn_state
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
# should be none in decode only state
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
else:
# chunked prefill
if num_prefills_pre > 0:
attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
seq_lens_pre)].contiguous()
attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_post = attn_metadata.attn_mask[
token_index:, :max(seq_lens_post)].contiguous()
else:
attn_state_pre = AscendAttentionState.DecodeOnly
attn_mask_pre = None
attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_post = attn_metadata.attn_mask[
token_index:, :max(seq_lens_post)].contiguous()
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAPrefillMetadata)
if num_prefills_pre > 0:
# split metadata.prefill
[input_positions_pre, input_positions_post] = split_attn_tensor_type(
attn_metadata.prefill.input_positions,
token_index - attn_metadata.num_decode_tokens)
[block_tables_pre, block_tables_post
] = split_attn_tensor_type(attn_metadata.prefill.block_table,
seq_index - attn_metadata.num_decodes)
[prefill_query_lens_pre, prefill_query_lens_post
] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
seq_index - attn_metadata.num_decodes)
prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[:
seq_index
+
1 -
attn_metadata
.
num_decodes]
prefill_query_start_loc_post = deepcopy(
attn_metadata.prefill.query_start_loc[seq_index -
attn_metadata.num_decodes:]
) - attn_metadata.prefill.query_start_loc[seq_index -
attn_metadata.num_decodes]
context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
context_len_post = seq_lens_post
prefill_max_query_len_pre = max(prefill_query_lens_pre)
prefill_max_query_len_post = max(prefill_query_lens_post)
prefill_pre = AscendMLAPrefillMetadata(
attn_mask=attn_mask_pre,
query_lens=prefill_query_lens_pre,
seq_lens=seq_lens_pre,
query_start_loc=prefill_query_start_loc_pre,
input_positions=input_positions_pre,
context_lens=context_len_pre,
block_table=block_tables_pre,
max_query_len=prefill_max_query_len_pre,
max_seq_lens=context_len_pre.max().item(),
)
prefill_post = AscendMLAPrefillMetadata(
attn_mask=attn_mask_post,
query_lens=prefill_query_lens_post,
seq_lens=seq_lens_post,
query_start_loc=prefill_query_start_loc_post,
input_positions=input_positions_post,
context_lens=context_len_post,
block_table=block_tables_post,
max_query_len=prefill_max_query_len_post,
max_seq_lens=context_len_post.max().item(),
)
decode_pre = attn_metadata.decode
decode_post = None
else:
# prefill is None, split metadata.decode
[input_positions_pre, input_positions_post
] = split_attn_tensor_type(attn_metadata.decode.input_positions,
token_index)
[block_tables_pre, block_tables_post
] = split_attn_tensor_type(attn_metadata.decode.block_table,
seq_index)
[decode_seq_lens_pre,
decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
decode_pre = AscendMLADecodeMetadata(
input_positions=input_positions_pre,
block_table=block_tables_pre,
seq_lens=decode_seq_lens_pre,
max_seq_lens=max(decode_seq_lens_pre),
seq_lens_list=decode_seq_lens_pre.tolist(),
)
decode_post = AscendMLADecodeMetadata(
input_positions=input_positions_post,
block_table=block_tables_post,
seq_lens=decode_seq_lens_post,
max_seq_lens=max(decode_seq_lens_post),
seq_lens_list=decode_seq_lens_post.tolist(),
)
prefill_pre = None
prefill_post = attn_metadata.prefill
# construct metadata
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
attention_metadata_pre = _metadata_cls(
num_actual_tokens=token_index,
num_input_tokens=token_index,
head_dim=attn_metadata.head_dim,
slot_mapping=slot_mapping_pre,
seq_lens=seq_lens_pre,
query_start_loc=query_start_loc_pre,
block_tables=block_table_pre,
num_decodes=num_decodes_pre,
num_prefills=num_prefills_pre,
num_decode_tokens=num_decode_tokens_pre,
attn_state=attn_state_pre,
attn_mask=attn_mask_pre,
prefill=prefill_pre,
decode=decode_pre,
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
)
attention_metadata_post = _metadata_cls(
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
num_input_tokens=attn_metadata.num_input_tokens - token_index,
head_dim=attn_metadata.head_dim,
slot_mapping=slot_mapping_post,
seq_lens=seq_lens_post,
query_start_loc=query_start_loc_post,
block_tables=block_table_post,
num_decodes=num_decodes_post,
num_prefills=num_prefills_post,
num_decode_tokens=num_decode_tokens_post,
attn_mask=attn_mask_post,
attn_state=attn_state_post,
prefill=prefill_post,
decode=decode_post,
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
)
return [attention_metadata_pre, attention_metadata_post]

View File

@@ -1151,3 +1151,32 @@ class AscendFusedMoE(FusedMoE):
if self.enable_multistream_shared_expert and not is_prefill: if self.enable_multistream_shared_expert and not is_prefill:
return hidden_states, shared_output return hidden_states, shared_output
return hidden_states return hidden_states
# ----------------------------------------- TBO-related --------------------------------------------
def _forward_ms_fused_moe_comp(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
real_top_k,
enable_force_load_balance: bool = False,
):
hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance)
return hidden_states