xc-llm-ascend/vllm_ascend/ascend_forward_context.py
drslark 384d84c7ef [Bugfix] Avoided a bug of drafter when dp and sp are enabled (#6226)
### What this PR does / why we need it?

Avoids a bug in the drafter that occurs when data parallelism (`dp`) and sequence parallelism (`sp`) are both enabled.

Specifically, `sp` is now disabled whenever the drafter is a dense model.
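
The resulting decision order can be sketched as a small pure function. This is only an illustration of the branch logic in `set_ascend_forward_context`: the helper name `decide_sp_enabled` and its parameters are hypothetical, `sp_configured` stands in for the result of `enable_sp(vllm_config)`, and `is_moe` stands in for the MoE check on whichever model (main or drafter) is currently running.

```python
from typing import Optional


def decide_sp_enabled(
    sp_configured: bool,
    num_tokens: Optional[int],
    is_draft_model: bool,
    is_moe: bool,
    dense_token_threshold: int = 1000,
) -> bool:
    """Hypothetical helper mirroring the branch order used for `sp_enabled`."""
    if is_moe:
        # MoE model (main or drafter): sp only needs to be configured.
        return bool(sp_configured and num_tokens is not None)
    if is_draft_model:
        # Dense drafter: sp is always off, regardless of configuration.
        return False
    # Dense main model: sp additionally requires a large enough batch.
    return bool(
        sp_configured
        and num_tokens is not None
        and num_tokens > dense_token_threshold
    )


# A dense drafter never uses sp, even for a large batch.
print(decide_sp_enabled(True, 4096, is_draft_model=True, is_moe=False))
# A dense main model with the same batch still does.
print(decide_sp_enabled(True, 4096, is_draft_model=False, is_moe=False))
```

Keeping the dense-drafter branch ahead of the threshold check is what guarantees the drafter never switches to `sp` as the batch grows.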

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

An aisbench test:

```shell
python3 aisbench_test.py --input_len 3500 --output_len 1000 --data_num 100 --concurrency 320 --request_rate 8
```

The run completed without failures: all 100 requests were posted and received.

```text
[2026-01-24 22:38:20,256] [ais_bench.benchmark.openicl.icl_inferencer.icl_gen_inferencer] [INFO] Calculate global interval offsets time: 0.5922 s
01/24 22:38:20 - AISBench - INFO - Process 0 using precomputed sleep offsets with 100 requests
Process-0 pid:220279: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:40<00:00,  5.81s/it]
Pid:      220279 | Post:        100 | Received:    100 | Failed:        0 | Post Time:12.51s | Receive Time:580.92s: 
Encoding output text...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 93.75it/s]
01/24 22:48:02 - AISBench - INFO - Start converting origin data to detailed data ...
01/24 22:48:02 - AISBench - INFO - Finish converting origin data to detailed data█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 95.08it/s]
01/24 22:48:02 - AISBench - INFO - Added 'Actual RPS: After Excluding Anomalies' to group 'Time - RPS: ' in legend explanation table
01/24 22:48:02 - AISBench - INFO - Successfully merged chart into position (1, 1)
01/24 22:48:02 - AISBench - INFO - RPS distribution charts saved to outputs/default/20260124_223809/performances/vllm-api-stream-chat/gsm8kdataset_rps_distribution_plot_with_actual_rps.html
01/24 22:48:02 - AISBench - INFO - Updated chart with actual RPS saved to outputs/default/20260124_223809/performances/vllm-api-stream-chat/gsm8kdataset_rps_distribution_plot_with_actual_rps.html
[2026-01-24 22:48:02,557] [ais_bench.benchmark.openicl.icl_inferencer.icl_gen_perf_inferencer] [INFO] Start extracting pref datas ...
[2026-01-24 22:48:02,558] [ais_bench.benchmark.openicl.icl_inferencer.icl_gen_perf_inferencer] [INFO] Finish extracting pref datas!
[2026-01-24 22:48:02,558] [ais_bench.benchmark.openicl.icl_inferencer.icl_gen_perf_inferencer] [INFO] Dumping detail perf data ...
Dumping data to h5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 75.31it/s]
[2026-01-24 22:48:02,588] [ais_bench.benchmark.openicl.icl_inferencer.icl_gen_perf_inferencer] [INFO] Dump detail perf data cost: 0.02995561994612217(s)
[2026-01-24 22:48:02,588] [ais_bench.benchmark.openicl.icl_inferencer.icl_gen_perf_inferencer] [INFO] Performance task finished, results saved in outputs/default/20260124_223809/performances/vllm-api-stream-chat
01/24 22:48:02 - AISBench - INFO - time elapsed: 586.32s
Running tasks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [09:55<00:00, 595.91s/it]
01/24 22:48:05 - AISBench - INFO - Performance evaluation tasks completed.
01/24 22:48:05 - AISBench - INFO - Loading detail perf data of model='vllm-api-stream-chat' dataset='gsm8kdataset' ...
01/24 22:48:05 - AISBench - INFO - Starting request timeline processing...
01/24 22:48:05 - AISBench - INFO - Data preprocessing completed in 0.0004s
01/24 22:48:05 - AISBench - INFO - Generating timeline traces for 100 requests...
01/24 22:48:05 - AISBench - INFO - Generated timeline trace chunks in 0.0441s
01/24 22:48:05 - AISBench - INFO - Generating concurrency traces...
01/24 22:48:05 - AISBench - INFO - Generated concurrency trace chunks in 0.0011s
01/24 22:48:05 - AISBench - INFO - Creating figure layout...
01/24 22:48:05 - AISBench - INFO - Figure layout created in 0.0504s
01/24 22:48:05 - AISBench - INFO - Writing to outputs/default/20260124_223809/performances/vllm-api-stream-chat/gsm8kdataset_plot.html...
01/24 22:48:05 - AISBench - INFO - HTML written in 0.0181s
01/24 22:48:05 - AISBench - INFO - Completed! Total execution time: 0.1148s
01/24 22:48:05 - AISBench - INFO - The gsm8kdataset_plot has been saved in outputs/default/20260124_223809/performances/vllm-api-stream-chat/gsm8kdataset_plot.html
01/24 22:48:05 - AISBench - INFO - Converting perf results of stage ...
01/24 22:48:05 - AISBench - INFO - Finish Converting!
01/24 22:48:05 - AISBench - INFO - Start calculating metrics ...
01/24 22:48:05 - AISBench - INFO - Start calculating common metrics ...
01/24 22:48:05 - AISBench - INFO - Start calculating add units ...
01/24 22:48:05 - AISBench - INFO - Finish calculating perf data!
01/24 22:48:05 - AISBench - INFO - Summarizing performance results...
01/24 22:48:05 - AISBench - INFO - Performance Results of task: vllm-api-stream-chat/gsm8kdataset: 
╒══════════════════════════╤═════════╤════════════════╤════════════════╤════════════════╤════════════════╤════════════════╤════════════════╤════════════════╤═════╕
│ Performance Parameters   │ Stage   │ Average        │ Min            │ Max            │ Median         │ P75            │ P90            │ P99            │  N  │
╞══════════════════════════╪═════════╪════════════════╪════════════════╪════════════════╪════════════════╪════════════════╪════════════════╪════════════════╪═════╡
│ E2EL                     │ total   │ 300806.1781 ms │ 189326.0489 ms │ 568345.5121 ms │ 380629.6785 ms │ 384208.3527 ms │ 385363.7709 ms │ 566871.7684 ms │ 100 │
├──────────────────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────┤
│ TTFT                     │ total   │ 107441.2231 ms │ 343.8054 ms    │ 378132.3979 ms │ 188817.4877 ms │ 190985.8451 ms │ 192547.6847 ms │ 378008.356 ms  │ 100 │
├──────────────────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────┤
│ TPOT                     │ total   │ 193.5585 ms    │ 185.1008 ms    │ 197.262 ms     │ 193.8146 ms    │ 195.0803 ms    │ 196.0323 ms    │ 196.9688 ms    │ 100 │
├──────────────────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────┤
│ ITL                      │ total   │ 194.2067 ms    │ 0.0108 ms      │ 2782.7124 ms   │ 184.9998 ms    │ 194.2631 ms    │ 221.2895 ms    │ 304.363 ms     │ 100 │
├──────────────────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────┤
│ InputTokens              │ total   │ 3506.86        │ 3431.0         │ 3508.0         │ 3508.0         │ 3508.0         │ 3508.0         │ 3508.0         │ 100 │
├──────────────────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────┤
│ OutputTokens             │ total   │ 1000.0         │ 1000.0         │ 1000.0         │ 1000.0         │ 1000.0         │ 1000.0         │ 1000.0         │ 100 │
├──────────────────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────┤
│ OutputTokenThroughput    │ total   │ 3.7745 token/s │ 1.7595 token/s │ 5.2819 token/s │ 2.6272 token/s │ 5.1028 token/s │ 5.1502 token/s │ 5.2754 token/s │ 100 │
╘══════════════════════════╧═════════╧════════════════╧════════════════╧════════════════╧════════════════╧════════════════╧════════════════╧════════════════╧═════╛
╒══════════════════════════╤═════════╤══════════════════╕
│ Common Metric            │ Stage   │ Value            │
╞══════════════════════════╪═════════╪══════════════════╡
│ Benchmark Duration       │ total   │ 580456.2704 ms   │
├──────────────────────────┼─────────┼──────────────────┤
│ Total Requests           │ total   │ 100              │
├──────────────────────────┼─────────┼──────────────────┤
│ Failed Requests          │ total   │ 0                │
├──────────────────────────┼─────────┼──────────────────┤
│ Success Requests         │ total   │ 100              │
├──────────────────────────┼─────────┼──────────────────┤
│ Concurrency              │ total   │ 51.8224          │
├──────────────────────────┼─────────┼──────────────────┤
│ Max Concurrency          │ total   │ 320              │
├──────────────────────────┼─────────┼──────────────────┤
│ Request Throughput       │ total   │ 0.1723 req/s     │
├──────────────────────────┼─────────┼──────────────────┤
│ Total Input Tokens       │ total   │ 350686           │
├──────────────────────────┼─────────┼──────────────────┤
│ Prefill Token Throughput │ total   │ 32.6398 token/s  │
├──────────────────────────┼─────────┼──────────────────┤
│ Total generated tokens   │ total   │ 100000           │
├──────────────────────────┼─────────┼──────────────────┤
│ Input Token Throughput   │ total   │ 604.1558 token/s │
├──────────────────────────┼─────────┼──────────────────┤
│ Output Token Throughput  │ total   │ 172.2783 token/s │
├──────────────────────────┼─────────┼──────────────────┤
│ Total Token Throughput   │ total   │ 776.434 token/s  │
╘══════════════════════════╧═════════╧══════════════════╛
01/24 22:48:05 - AISBench - INFO - Performance Result files locate in outputs/default/20260124_223809/performances/vllm-api-stream-chat.
```
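
As a quick sanity check on the summary table, the throughput rows can be recomputed from the reported totals and the benchmark duration. The snippet below is a sketch that assumes each throughput is simply the total divided by the benchmark duration; that assumption is not confirmed by the AISBench output itself, and the variable names are illustrative.

```python
# Values copied from the "Common Metric" table above.
benchmark_duration_s = 580456.2704 / 1000.0  # reported in milliseconds
total_requests = 100
total_input_tokens = 350686
total_generated_tokens = 100000

# Assumed definition: total count divided by wall-clock benchmark duration.
request_throughput = total_requests / benchmark_duration_s
input_token_throughput = total_input_tokens / benchmark_duration_s
output_token_throughput = total_generated_tokens / benchmark_duration_s
total_token_throughput = input_token_throughput + output_token_throughput

print(f"Request throughput:      {request_throughput:.4f} req/s")
print(f"Input token throughput:  {input_token_throughput:.4f} token/s")
print(f"Output token throughput: {output_token_throughput:.4f} token/s")
print(f"Total token throughput:  {total_token_throughput:.4f} token/s")
```

Under that assumption the recomputed values agree with the table rows for request, input, output, and total token throughput to the precision shown there.
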
- vLLM version: v0.14.0
- vLLM main: d68209402d

Signed-off-by: drslark <slarksblood@qq.com>
2026-01-25 17:45:29 +08:00

273 lines, 11 KiB, Python

import math
from contextlib import contextmanager
from enum import Enum
from typing import Any

import torch
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (
    AscendDeviceType,
    enable_sp,
    flashcomm2_enable,
    get_ascend_device_type,
    has_layer_idx,
    is_drafter_moe_model,
    is_moe_model,
    speculative_enable_dispatch_gmm_combine_decode,
)


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    FUSED_MC2 = 3


@contextmanager
def set_ascend_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int = 0,
    num_tokens_across_dp: torch.Tensor | None = None,
    in_profile_run: bool = False,
    num_actual_tokens: int | None = None,
    aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    model_instance: torch.nn.Module | None = None,
    is_draft_model: bool = False,
):
    """A context manager that stores the current forward context
    (attention metadata, etc.) and attaches additional
    Ascend-specific fields to it.
    """
    with set_forward_context(
        attn_metadata,
        vllm_config,
        virtual_engine=virtual_engine,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=aclgraph_runtime_mode,
        batch_descriptor=batch_descriptor,
    ):
        forward_context = get_forward_context()

        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method

        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
        forward_context.moe_comm_type = moe_comm_type
        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

        tp_world_size = get_tensor_model_parallel_world_size()

        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False

        # TODO: remove it when torch_npu.npu_mm_reduce_scatter_base supports tp_size >= 16.
        mmrs_fusion = tp_world_size <= 8

        # Sequence parallelism: 1000 is the threshold (compared against num_tokens)
        # for enabling the flashcomm_v1 or sequence_parallelism feature on dense models.
        # It is currently an empirical value. In normal scenarios, the performance
        # benefit is largest when the threshold is exceeded.
        # Conversely, below the threshold, performance may degrade
        # because the communication method is switched.
        # The main model and the drafter model may have different architectures.
        is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
        if is_context_moe_model:
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
            mmrs_fusion = False
        elif is_draft_model:
            # TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
            # Disable it to avoid more problems.
            sp_enabled = False
        else:
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
        forward_context.mmrs_fusion = mmrs_fusion
        forward_context.num_tokens = num_tokens
        forward_context.sp_enabled = sp_enabled
        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
        forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None

        if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
            # Pad the token count up to the next multiple of tp_world_size.
            pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size

        # Used by the rope forward_oot implementation.
        forward_context.is_first_layer = True

        # Set layer_idx to enable optimization features that depend on this information.
        # This is only applicable to models that contain these necessary attributes.
        forward_context.layer_idx = None
        if has_layer_idx(model_instance):
            forward_context.layer_idx = model_instance.model.start_layer

        # TODO(rjg-lyh): refactor mlp weight prefetch method
        # Set for MLP weight prefetch.
        prefetch_mlp_enabled = (
            envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP
            and forward_context.layer_idx is not None
            and num_tokens is not None
            and num_tokens < 500
        )
        if prefetch_mlp_enabled:
            forward_context.prefetch_mlp_gate_up_proj = False
            forward_context.prefetch_mlp_down_proj = False
        forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
        forward_context.model_instance = model_instance
        forward_context.is_draft_model = is_draft_model

        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens

        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
            if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
                # With dp, pad to the largest per-rank token count, rounded up to a multiple of tp_world_size.
                padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
                forward_context.padded_length = padded_length
                forward_context.pad_size = pad_size
        else:
            max_tokens_across_dp = num_tokens

        forward_context.max_tokens_across_dp = max_tokens_across_dp

        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            # NOTE: the token count to pad to when MC2 is used.
            forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
            reserved_mc2_mask = get_mc2_mask()
            if reserved_mc2_mask is not None:
                mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
                mc2_mask[:num_actual_tokens] = True
                mc2_mask[num_actual_tokens:] = False
                forward_context.mc2_mask = mc2_mask

        try:
            yield
        finally:
            pass


_mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: torch.Tensor | None = None
_sin: torch.Tensor | None = None
_cos: torch.Tensor | None = None


def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
    global _mc2_tokens_capacity
    if _mc2_tokens_capacity is not None:
        return
    if vllm_config.compilation_config.cudagraph_capture_sizes:
        max_num_tokens = vllm_config.compilation_config.max_cudagraph_capture_size
    else:
        # NOTE: To save memory, we cap the max number of tokens to 512.
        max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    tp_size = vllm_config.parallel_config.tensor_parallel_size
    # Use integer arithmetic for ceiling division.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    _mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size


def get_mc2_tokens_capacity():
    return _mc2_tokens_capacity


def set_mc2_mask(vllm_config, device):
    global _reserved_mc2_mask
    if _reserved_mc2_mask is not None:
        return
    if is_moe_model(vllm_config):
        _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
    else:
        _reserved_mc2_mask = None


def get_mc2_mask():
    return _reserved_mc2_mask


def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model=False) -> MoECommType | None:
    """Select the MoE communication method according to parallel settings,
    device generation, token count, and quantization.

    1. Non-MoE models return `None`.
    2. Without expert parallel, fall back to all-gather.
    3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
       and the DP size is large enough; otherwise use all-gather.
    4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
       quantization with small EP size, no dynamic_eplb, and not in MTP
       mode; otherwise use MC2 within capacity or all-to-all.

    Args:
        num_tokens (int): The number of tokens in the current batch.
        vllm_config (VllmConfig): Runtime configuration for the model.
        is_draft_model (bool): Whether the model runs in MTP mode (disables fused MC2).

    Raises:
        ValueError: If the soc version is unsupported.

    Returns:
        MoECommType | None: The selected MoE communication method.
    """
    if not is_moe_model(vllm_config):
        return None
    mc2_tokens_capacity = get_mc2_tokens_capacity()
    soc_version = get_ascend_device_type()
    quant_type = getattr(
        vllm_config.model_config.hf_text_config,
        "moe_quantize",
        getattr(vllm_config.model_config.hf_text_config, "quantize", None),
    )

    if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
        moe_comm_type = MoECommType.ALLGATHER
    elif soc_version in {AscendDeviceType.A2}:
        if (
            num_tokens <= mc2_tokens_capacity
            and vllm_config.parallel_config.world_size_across_dp / vllm_config.parallel_config.pipeline_parallel_size
            >= 16
        ):
            moe_comm_type = MoECommType.MC2
        else:
            moe_comm_type = MoECommType.ALLGATHER

    elif soc_version in {AscendDeviceType.A3}:
        dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
        # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
        # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
        if num_tokens <= mc2_tokens_capacity:
            fused_decode_enable = fused_mc2_enable
            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
                fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
            moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
        else:
            fused_prefill_enable = fused_mc2_enable
            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
                fused_prefill_enable = False
            moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
    else:
        raise ValueError(f"Unsupported soc_version: {soc_version}")
    return moe_comm_type