xc-llm-ascend/vllm_ascend/multistream/context.py
zxdukki 87ebaef4e4 [perf]: support dual-batch overlap(dbo) for deepseek (#941)
### What this PR does / why we need it?
Based on the dual-batch overlap design proposed by the DeepSeek team and
the fused MoE implementation in the vLLM project, we implement
multi-stream (also known as dual-batch) overlap for DeepSeek + MLA on
Ascend NPU. We split the model's input batch into two micro-batches and
then overlap the computation and communication ops in the attention and
MoE layers using two streams to improve performance. Our approach can be
easily extended when dispatch/combine communications are added to the
MoE layer.
Compared with the previously proposed
[draft](https://github.com/vllm-project/vllm-ascend/pull/842), we use
one stream for computation ops and the other for communication ops.
In our opinion, this makes it easier to arrange the execution order of
different ops and thus avoids contention for computation/communication
resources.
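The scheduling idea can be illustrated with a toy timeline model. This is only a sketch under stated assumptions, not the scheduler used in this PR: the `Op` class, `simulate`, `layer_ops`, and all durations are hypothetical, and it assumes that splitting a batch in half also halves each op's duration. Ops on the same stream run in issue order, and each op waits for the previous op of its own micro-batch.

```python
from dataclasses import dataclass


@dataclass
class Op:
    micro_batch: int   # which micro-batch the op belongs to
    stream: str        # "comp" or "comm" (hypothetical stream names)
    duration: float    # made-up time units


def simulate(ops):
    """Return the finish time when ops are issued in list order.

    An op starts once its stream is free AND the previous op of the
    same micro-batch has finished.
    """
    stream_free = {"comp": 0.0, "comm": 0.0}
    batch_ready = {}
    for op in ops:
        start = max(stream_free[op.stream], batch_ready.get(op.micro_batch, 0.0))
        end = start + op.duration
        stream_free[op.stream] = end
        batch_ready[op.micro_batch] = end
    return max(stream_free.values())


def layer_ops(micro_batch, comp, comm):
    # One layer in this toy model: attention compute, comm, MoE compute, comm.
    return [
        Op(micro_batch, "comp", comp),
        Op(micro_batch, "comm", comm),
        Op(micro_batch, "comp", comp),
        Op(micro_batch, "comm", comm),
    ]


# Whole batch as a single unit: nothing can overlap.
serial_time = simulate(layer_ops(0, comp=4.0, comm=2.0))

# Two half-size micro-batches, issued alternately op by op.
mb0 = layer_ops(0, comp=2.0, comm=1.0)
mb1 = layer_ops(1, comp=2.0, comm=1.0)
interleaved = [op for pair in zip(mb0, mb1) for op in pair]
overlap_time = simulate(interleaved)

print(serial_time, overlap_time)  # 12.0 9.0
```

In this model the communication of one micro-batch hides behind the computation of the other, so only the first compute op and the last communication op remain exposed.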

ref: [overlap for
llama](https://github.com/vllm-project/vllm/pull/15787/files)
ref: [dbo in
sglang](https://github.com/sgl-project/sglang/pull/4068/files#diff-b4937569fc71f6ad215181b633b2f89c7183a2b4ac39e41fc22635599a9be7de)

### Does this PR introduce _any_ user-facing change?
Adds an env variable "VLLM_ASCEND_ENABLE_DBO". Users can enable dbo by
setting "VLLM_ASCEND_ENABLE_DBO=1".
See /examples/offline_dualbatch_overlap_npu.py for more info.
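
A minimal sketch of turning the switch on from Python. It assumes the variable is read when vllm / vllm-ascend start up, so it is set before any import of them. The `dbo_enabled` helper is hypothetical and only shows one common way to parse a 0/1 flag; it is not taken from vllm-ascend.

```python
import os

# Set the flag before importing vllm / vllm_ascend
# (assumption: the flag is read at import or startup time).
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"


def dbo_enabled() -> bool:
    """Hypothetical helper: treat the variable as an integer flag, default off."""
    return bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", "0")))


print(dbo_enabled())
```

For the online service, exporting the variable in the shell before launching the server has the same effect.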

### How was this patch tested?

This patch can be tested with vllm-0.9.0 using its online service with
benchmark tests. We have decoupled the dbo functionality from vllm, so it
should run without any modification to the vllm code (although some
changes would be better implemented in vllm itself).



Any advice/discussion is welcome.

### Performance Benchmark

We ran the benchmark_serving script of vllm to measure performance
with dual-batch overlap enabled.

`python -m vllm.entrypoints.openai.api_server \
 --model=DeepSeek-R1-W8A8 \
 --trust-remote-code \
 --distributed-executor-backend=mp \
 -tp=16 \
 --port 8006 \
 --max-num-seqs 390 \
 --max-model-len 32768 \
 --max-num-batched-tokens 65536 \
 --block-size 128 \
 --compilation_config 0 \
 --gpu-memory-utilization 0.90 \
 --disable-log-requests \
 --additional-config \
'{"expert_tensor_parallel_size":1,"enable_inter_dp_scheduling":true,"init_torchair_graph_batch_sizes":true,"trace_recompiles":true,"ascend_scheduler_config":{},"enable_graph_mode":false}'`

and ran the benchmark with the following parameters:
`--dataset-name random --random-input-len 4096 --random-output-len 1
--num-prompts 200 --max-concurrency 8 --request-rate 5
--metric-percentiles 90`

1. Test with the version using allgather+allreduce on Ascend 910B (tp16
ep16 + deepseek r1 w8a8)

2. Test with the version using alltoall:

Prefill QPS: 0.90 -> 1.01
Mean TTFT: 8226 ms -> 7432 ms
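
The relative change implied by the alltoall figures above can be worked out as follows. This is plain arithmetic on the reported numbers; the `relative_change` helper is introduced here only for illustration.

```python
def relative_change(before: float, after: float) -> float:
    """Percentage change from `before` to `after`."""
    return (after - before) / before * 100.0


qps_gain = relative_change(0.90, 1.01)      # prefill QPS
ttft_change = relative_change(8226, 7432)   # mean TTFT in ms

print(f"prefill QPS: {qps_gain:+.1f}%")     # +12.2%
print(f"mean TTFT:   {ttft_change:+.1f}%")  # -9.7%
```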

The overlap approach with alltoall communication can be further
optimized by overlapping micro-batch 1's MoE computation with
micro-batch 2's dispatch all-to-all communication.

---------

Signed-off-by: zhuohuan <zxdu1997@gmail.com>
2025-06-07 16:46:58 +08:00


from contextlib import contextmanager
from typing import Any

# Module-level state shared by the multistream helpers below.
_ms_comm_context: Any = None
_cur_micro_batch_num: int = -1
_ms_layer_index_context: int = -1
_ms_metadata_context: Any = None
_ms_attn_metadata_context: Any = None


def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
                                  attn_metadata: Any):
    """
    set multistream layer context before transformer layers
    """
    global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
    _ms_layer_index_context = start_layer
    _ms_metadata_context = ms_metadata
    _ms_attn_metadata_context = attn_metadata


def reset_multistream_layer_context():
    """
    reset multistream layer context
    """
    global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
    _ms_layer_index_context = -1
    _ms_metadata_context = None
    _ms_attn_metadata_context = None


def get_multistream_layer_context():
    """
    get multistream layer context
    """
    return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context


def advance_step_multistream_layer_context():
    """
    advance multistream layer index context
    """
    global _ms_layer_index_context
    _ms_layer_index_context += 1


def get_multistream_comm_context() -> Any:
    """Get the current comm forward context."""
    return _ms_comm_context


def get_multistream_microbatch_context() -> int:
    """Get the current micro-batch number (-1 when none is set)."""
    return _cur_micro_batch_num


@contextmanager
def set_multistream_context(context: Any, micro_batch_num: int):
    """A context manager that stores the current comm forward context,
    can be attention metadata, etc."""
    global _ms_comm_context, _cur_micro_batch_num
    _ms_comm_context = context
    _cur_micro_batch_num = micro_batch_num
    try:
        yield
    finally:
        # Always clear the state, even if the body raised.
        _ms_comm_context = None
        _cur_micro_batch_num = -1
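
The scoping behaviour of `set_multistream_context` can be checked with a short standalone sketch. The helpers are repeated in condensed form only so the snippet runs without vllm-ascend installed; `fake_comm_context` is a hypothetical stand-in for whatever object a caller would store.

```python
from contextlib import contextmanager
from typing import Any

# Condensed copy of the relevant state and helpers from context.py.
_ms_comm_context: Any = None
_cur_micro_batch_num: int = -1


def get_multistream_comm_context() -> Any:
    return _ms_comm_context


def get_multistream_microbatch_context() -> int:
    return _cur_micro_batch_num


@contextmanager
def set_multistream_context(context: Any, micro_batch_num: int):
    global _ms_comm_context, _cur_micro_batch_num
    _ms_comm_context = context
    _cur_micro_batch_num = micro_batch_num
    try:
        yield
    finally:
        _ms_comm_context = None
        _cur_micro_batch_num = -1


def snapshot():
    return get_multistream_comm_context(), get_multistream_microbatch_context()


fake_comm_context = {"name": "comm-ctx-for-mb0"}  # hypothetical payload

# Inside the block the getters see the values that were set.
with set_multistream_context(fake_comm_context, micro_batch_num=0):
    inside = snapshot()
after = snapshot()

# The finally clause also runs when the body raises.
try:
    with set_multistream_context(fake_comm_context, 1):
        raise RuntimeError("simulated failure inside a micro-batch")
except RuntimeError:
    pass
after_error = snapshot()

# Nesting: leaving the inner block clears the state instead of
# restoring the outer values.
with set_multistream_context("outer", 0):
    with set_multistream_context("inner", 1):
        pass
    after_inner = snapshot()

print(inside, after, after_error, after_inner)
```

Because the state is cleared rather than restored on exit, the context manager as written suits sequential, non-nested use, one micro-batch at a time.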