xc-llm-ascend/vllm_ascend/ascend_forward_context.py
Levi 9976e685b7 [Bugfix][eager][oom] fix rank0 load imbalance by no padding when multi dp (#7297)
### What this PR does / why we need it?
Fix the multi-DP padding logic for eager mode, because it causes rank0 load
imbalance in kimi-k2.5-w4a8: all the padding tokens are routed to rank0. The
fix also applies to other models running with multi-DP.
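A minimal sketch of the behavior change (illustrative only, not the actual patch; values are made up):

```python
# Tokens scheduled on each DP rank for one eager-mode step.
num_tokens_across_dp = [8, 512, 16, 32]

# Before: every rank padded its batch up to the global max, and the dummy
# padding tokens were all routed to rank0, skewing its load and HBM usage.
padded_before = [max(num_tokens_across_dp)] * len(num_tokens_across_dp)  # [512, 512, 512, 512]

# After: eager mode does not need uniform shapes across ranks, so each rank
# keeps its own token count and no padding tokens exist to be routed.
padded_after = list(num_tokens_across_dp)  # [8, 512, 16, 32]
```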
- Before
HBM usage:
<img width="2229" height="733" alt="image"
src="https://github.com/user-attachments/assets/50479b6d-cfd0-4206-8e80-974024652997"
/>

Performance:
```shell
Concurrency  NumPrompts   QPS          TTFT_Avg     TTFT_P50     TPOT_Avg     TPOT_P50     TPOT_P90    
============ ============ ============ ============ ============ ============ ============ ============
1            15           0.0179       1667.7803    1673.3437    35.2973      35.2775      35.3784     
32           480          0.4725       2764.8027    1905.2137    40.8030      40.6978      41.0179     
64           960          0.7820       4123.7096    3485.6153    48.0461      48.1598      48.2971     
100          1500         1.0852       6216.7988    5714.0082    52.9323      53.0613      54.6304     
108          1620         1.1040       6277.4892    5798.7425    56.3862      56.9224      57.2901     
116          1740         1.1680       6563.3293    6039.5659    56.9894      57.4027      57.5786     
128          1920         1.2555       7822.5551    7604.1662    57.7660      58.1768      58.2717     
192          2880         1.4314       9212.1953    9131.3461    58.9905      59.1683      59.2791     
256          3840         1.4480       9028.0812    8913.7937    59.0092      59.2385      59.3516  
```


- After
HBM usage:
<img width="2246" height="1005" alt="image"
src="https://github.com/user-attachments/assets/d0936481-5a58-4bc5-a6f1-b92735d47885"
/>


Performance:
```shell
Concurrency  NumPrompts   QPS          TTFT_Avg     TTFT_P50     TPOT_Avg     TPOT_P50     TPOT_P90    
============ ============ ============ ============ ============ ============ ============ ============
1            15           0.0181       601.4171     600.9774     35.6270      35.6254      35.6480     
32           480          0.4455       720.8782     724.2889     45.4250      45.4755      45.6318     
64           960          0.8445       729.6209     728.2149     47.0464      47.0896      47.1985     
100          1500         1.2601       723.4834     724.6673     48.3108      48.3844      48.5355     
108          1620         1.3409       727.1509     720.6772     48.8962      48.9409      49.0489     
116          1740         1.4080       679.9799     677.6119     49.1253      49.1983      49.3087     
128          1920         1.4155       680.6284     674.9436     49.2193      49.2450      49.3763     
192          2880         1.4422       684.6577     676.7833     49.2059      49.2264      49.3229     
256          3840         1.4558       685.2462     678.1709     49.2191      49.2351      49.3419    
```
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: fny-coder <985619145@qq.com>
2026-03-23 17:05:02 +08:00


import math
from contextlib import contextmanager
from enum import Enum
from typing import Any

import torch
import vllm.envs as envs_vllm
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import (
    AscendDeviceType,
    enable_sp,
    flashcomm2_enable,
    get_ascend_device_type,
    has_layer_idx,
    is_drafter_moe_model,
    is_moe_model,
    speculative_enable_dispatch_gmm_combine_decode,
    vllm_version_is,
)


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    FUSED_MC2 = 3


@contextmanager
def set_ascend_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int = 0,
    num_tokens_across_dp: torch.Tensor | None = None,
    in_profile_run: bool = False,
    num_actual_tokens: int | None = None,
    aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    model_instance: torch.nn.Module | None = None,
    is_draft_model: bool = False,
    skip_compiled: bool = False,
    max_tokens_across_pcp: int = 0,
    draft_attn_metadatas=None,
):
    """A context manager that stores the current forward context,
    which can be attention metadata, etc.
    We add some additional params into the forward_context.
    """
    forward_context_kwargs = {
        "attn_metadata": attn_metadata,
        "vllm_config": vllm_config,
        "num_tokens": num_tokens,
        "num_tokens_across_dp": num_tokens_across_dp,
        "cudagraph_runtime_mode": aclgraph_runtime_mode,
        "batch_descriptor": batch_descriptor,
        "skip_compiled": skip_compiled,
    }
    if vllm_version_is("0.18.0"):
        forward_context_kwargs["virtual_engine"] = virtual_engine
    with set_forward_context(**forward_context_kwargs):
        forward_context = get_forward_context()
        forward_context.draft_attn_metadatas = draft_attn_metadatas

        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method

        max_num_tokens = int(num_tokens_across_dp.max().item()) if num_tokens_across_dp is not None else num_tokens
        moe_comm_type = select_moe_comm_method(max_num_tokens, vllm_config, is_draft_model)
        forward_context.moe_comm_type = moe_comm_type
        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
        tp_world_size = get_tensor_model_parallel_world_size()
        forward_context.in_profile_run = in_profile_run
        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing.
        forward_context.capturing = False
        # TODO: remove this when torch_npu.npu_mm_reduce_scatter_base supports tp_size >= 16.
        mmrs_fusion = tp_world_size <= 8
        # Set for sequence parallelism: 1000 is the concurrency (batch size)
        # threshold for enabling the flashcomm_v1 / sequence_parallelism feature.
        # It is currently an empirical value: above the threshold the performance
        # benefit is maximized, while below it performance may degrade because of
        # the switch of communication method.
        # The main model and the drafter model may have different architectures.
        is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
        if is_context_moe_model:
            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
            mmrs_fusion = False
        elif is_draft_model:
            # TODO: for a dense drafter, `sp` is redundant and is not compatible
            # with `dp` and `graph`. Disable it to avoid further problems.
            flash_comm_v1_enabled = False
        else:
            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
        forward_context.mmrs_fusion = mmrs_fusion
        forward_context.num_tokens = num_tokens
        forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2.
        forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
        forward_context.pad_size = 0
        if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
            # Pad num_tokens up to the next multiple of tp_world_size so the
            # sequence can be split evenly across TP ranks.
            pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size
        # Set this for rope forward_oot usage.
        forward_context.is_first_layer = True
        # Set layer_idx to enable optimization features that depend on this
        # information. This is only applicable to models that contain the
        # necessary attributes.
        forward_context.layer_idx = None
        if has_layer_idx(model_instance):
            forward_context.layer_idx = model_instance.model.start_layer
        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
        forward_context.model_instance = model_instance
        forward_context.is_draft_model = is_draft_model
        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens
        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
            if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
                padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
                forward_context.padded_length = padded_length
                forward_context.pad_size = pad_size
        else:
            max_tokens_across_dp = num_tokens
        forward_context.max_tokens_across_dp = max_tokens_across_dp
        forward_context.max_tokens_across_pcp = max_tokens_across_pcp
        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            # NOTE: the token count to pad up to when using MC2.
            forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
            reserved_mc2_mask = get_mc2_mask()
            if reserved_mc2_mask is not None:
                mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
                mc2_mask[:num_actual_tokens] = True
                mc2_mask[num_actual_tokens:] = False
                forward_context.mc2_mask = mc2_mask
        try:
            yield
        finally:
            pass
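

# Example usage (illustrative; argument names are hypothetical, the actual call
# sites are the model runners):
#
#     with set_ascend_forward_context(attn_metadata, vllm_config,
#                                     num_tokens=num_input_tokens):
#         hidden_states = model(input_ids, positions)
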
_mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: torch.Tensor | None = None


def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
    global _mc2_tokens_capacity
    if _mc2_tokens_capacity is not None:
        return
    if vllm_config.compilation_config.cudagraph_capture_sizes:
        max_num_tokens = vllm_config.compilation_config.max_cudagraph_capture_size
    else:
        # NOTE: To save memory, we cap the max number of tokens at 512.
        max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    tp_size = vllm_config.parallel_config.tensor_parallel_size
    # Use integer arithmetic for ceiling division.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    _mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
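
# Example (illustrative): with max_num_tokens = 500 and tp_size = 8, the
# ceiling division gives 63 tokens per TP rank, so the reserved capacity is
# rounded up to 63 * 8 = 504 tokens.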


def get_mc2_tokens_capacity():
    return _mc2_tokens_capacity


def set_mc2_mask(vllm_config, device):
    global _reserved_mc2_mask
    if _reserved_mc2_mask is not None:
        return
    if is_moe_model(vllm_config):
        _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
    else:
        _reserved_mc2_mask = None


def get_mc2_mask():
    return _reserved_mc2_mask
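
# Example (illustrative): with padded_num_tokens = 8 and num_actual_tokens = 5,
# set_ascend_forward_context slices this reserved mask down to
# [True] * 5 + [False] * 3 and stores it on the forward context as mc2_mask,
# marking which slots hold real tokens rather than padding.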


def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model: bool = False) -> MoECommType | None:
    """Select the MoE communication method according to parallel settings,
    device generation, token count, and quantization.

    1. Non-MoE models return `None`.
    2. Without expert parallel, fall back to all-gather.
    3. On A2 with expert parallel, pick MC2 when the tokens fit the MC2 capacity
       and the EP size is large enough; otherwise use all-gather.
    4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
       quantization with a small EP size, no dynamic_eplb, and not in MTP
       mode; otherwise use MC2 within capacity, or all-to-all.
    5. On 310P, always use all-gather.
    6. On A5, use MC2 when the tokens fit the MC2 capacity and the world size
       across DP is greater than 1; otherwise use all-to-all.

    Args:
        num_tokens (int): The number of tokens in the current batch.
        vllm_config (VllmConfig): Runtime configuration for the model.
        is_draft_model (bool): Whether the model runs in MTP mode (disables fused MC2).

    Raises:
        ValueError: If the SoC version is unsupported.

    Returns:
        MoECommType | None: The selected MoE communication method.
    """
    if not is_moe_model(vllm_config):
        return None
    mc2_tokens_capacity = get_mc2_tokens_capacity()
    soc_version = get_ascend_device_type()
    quant_type = getattr(
        vllm_config.model_config.hf_text_config,
        "moe_quantize",
        getattr(vllm_config.model_config.hf_text_config, "quantize", None),
    )
    if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
        moe_comm_type = MoECommType.ALLGATHER
    elif soc_version in {AscendDeviceType.A2}:
        num_experts = vllm_config.model_config.get_num_experts()
        ep_world_size = (
            vllm_config.parallel_config.world_size_across_dp // vllm_config.parallel_config.pipeline_parallel_size
        )
        num_experts_per_device = num_experts // ep_world_size
        if num_experts_per_device <= 24 and ep_world_size >= 16 and num_tokens <= mc2_tokens_capacity:
            moe_comm_type = MoECommType.MC2
        else:
            moe_comm_type = MoECommType.ALLGATHER
    elif soc_version in {AscendDeviceType.A3}:
        # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes.
        # TODO: drop the speculative-method guard when dispatch_gmm_combine_decode supports w16a16.
        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2
        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
        if num_tokens <= mc2_tokens_capacity:
            fused_decode_enable = fused_mc2_enable
            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
                fused_decode_enable = (
                    fused_mc2_enable
                    and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
                    and quant_type == "w8a8_dynamic"
                )
            moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
        else:
            fused_prefill_enable = fused_mc2_enable
            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
                fused_prefill_enable = False
            moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
    elif soc_version in {AscendDeviceType._310P}:
        moe_comm_type = MoECommType.ALLGATHER
    elif soc_version in {AscendDeviceType.A5}:
        if num_tokens <= mc2_tokens_capacity and vllm_config.parallel_config.world_size_across_dp > 1:
            moe_comm_type = MoECommType.MC2
        else:
            moe_comm_type = MoECommType.ALLTOALL
    else:
        raise ValueError(f"Unsupported soc_version: {soc_version}")
    return moe_comm_type
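
# Example (illustrative): on A2 with expert parallel enabled, an EP world size
# of 16 and 8 experts per device, a batch whose num_tokens fits within
# get_mc2_tokens_capacity() selects MoECommType.MC2, while a larger batch
# falls back to MoECommType.ALLGATHER.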


class _ExtraForwardContextProxy:
    """Unified forward-context access for v1/v2 model runners."""

    extra_attrs = (
        "capturing",
        "moe_comm_type",
        "moe_comm_method",
        "mmrs_fusion",
        "num_tokens",
        "flash_comm_v1_enabled",
        "flashcomm_v2_enabled",
        "pad_size",
        "padded_length",
        "num_tokens_across_dp",
        "mc2_mask",
        "is_draft_model",
        "prefetch_mlp_gate_up_proj",
        "prefetch_mlp_down_proj",
        "model_instance",
        "layer_idx",
        "max_tokens_across_dp",
        "max_tokens_across_pcp",
        "num_accept_tokens",
        "in_profile_run",
        "padded_num_tokens",
    )

    def check_extra_attr(self, name: str):
        if name not in self.extra_attrs:
            raise AttributeError(
                f"{name} is not an extra forward context attribute; "
                "please get/set it from vllm's _forward_context directly."
            )

    @staticmethod
    def _ctx():
        return get_forward_context()

    def __getattr__(self, name: str) -> Any:
        self.check_extra_attr(name)
        ctx = self._ctx()
        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
            return ctx.additional_kwargs[name]
        return getattr(ctx, name)

    def __setattr__(self, name: str, value: Any) -> None:
        self.check_extra_attr(name)
        ctx = self._ctx()
        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
            ctx.additional_kwargs[name] = value
        else:
            setattr(ctx, name, value)


# usage: from vllm_ascend.ascend_forward_context import _EXTRA_CTX
_EXTRA_CTX = _ExtraForwardContextProxy()
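
# Example (illustrative): reading an extra attribute set by
# set_ascend_forward_context, independent of the v1/v2 model-runner split.
#
#     from vllm_ascend.ascend_forward_context import _EXTRA_CTX
#     if _EXTRA_CTX.flash_comm_v1_enabled:
#         ...  # take the sequence-parallel code path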