Files
xc-llm-ascend/vllm_ascend/ascend_forward_context.py
pu-zhe 85e33941e8 [Feat.]: 310p support MOE models (#6530)
### What this PR does / why we need it?
This pull request integrates comprehensive support for Mixture of
Experts (MoE) models on the Ascend 310P device within the vllm-ascend
framework. It achieves this by introducing specialized modules for
expert selection, fused MoE layers, and optimized all-gather
communication. The changes also refine existing NPU operations, making
them more consistent and efficient for 310P, ultimately enhancing the
performance and compatibility of MoE models on this hardware.

Highlights
- 310P MoE Support: Introduces dedicated implementations for Mixture of
  Experts (MoE) models on Ascend 310P devices, including new modules for
  expert selection, fused MoE layers, and communication.
- All-Gather Communication: Enforces the use of ALLGATHER communication
  for MoE operations on 310P, optimizing data transfer and leveraging
  NPU-specific token dispatching.
- Simplified NPU Operations: Removes conditional type casting for
  npu_swiglu and enables custom rotary embedding kernels unconditionally,
  suggesting improved native support for 310P.
- New MoE Classes Registered: Registers AscendFusedMoE310 and
  AscendSharedFusedMoE310 to integrate 310P-specific MoE layers into the
  system's custom operation registry (see the sketch after this list).

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Offline test and server test with qwen3-30b-a3b, TP/EP 4, on 310P.
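For reference, a minimal sketch of the kind of offline test described above, assuming vllm-ascend is installed on the 310P host and that the checkpoint is the public Qwen/Qwen3-30B-A3B model; the exact model path, prompts, and sampling settings used in the original test are not stated in this PR.

```python
# Minimal offline-inference sketch (model ID, prompt, and sampling settings are
# assumptions). Expert parallelism is enabled so MoE layers exercise the 310P
# all-gather path added in this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # assumed checkpoint; the PR only says "qwen3-30b-a3b"
    tensor_parallel_size=4,        # TP 4, matching the reported test setup
    enable_expert_parallel=True,   # EP across the same 4 devices
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```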

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
2026-02-06 10:30:56 +08:00

271 lines
11 KiB
Python

import math
from contextlib import contextmanager
from enum import Enum
from typing import Any
import torch
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (
    AscendDeviceType,
    enable_sp,
    flashcomm2_enable,
    get_ascend_device_type,
    has_layer_idx,
    is_drafter_moe_model,
    is_moe_model,
    speculative_enable_dispatch_gmm_combine_decode,
)


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    FUSED_MC2 = 3


@contextmanager
def set_ascend_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int = 0,
    num_tokens_across_dp: torch.Tensor | None = None,
    in_profile_run: bool = False,
    num_actual_tokens: int | None = None,
    aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    model_instance: torch.nn.Module = None,
    is_draft_model=False,
    skip_compiled: bool = False,
    draft_attn_metadatas=None,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
"""
    forward_context_kwargs = {
        "attn_metadata": attn_metadata,
        "vllm_config": vllm_config,
        "virtual_engine": virtual_engine,
        "num_tokens": num_tokens,
        "num_tokens_across_dp": num_tokens_across_dp,
        "cudagraph_runtime_mode": aclgraph_runtime_mode,
        "batch_descriptor": batch_descriptor,
        "skip_compiled": skip_compiled,
    }
    with set_forward_context(**forward_context_kwargs):
        forward_context = get_forward_context()
        forward_context.draft_attn_metadatas = draft_attn_metadatas

        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method

        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
        forward_context.moe_comm_type = moe_comm_type
        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

        tp_world_size = get_tensor_model_parallel_world_size()
        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False

        # TODO: remove it when torch_npu.npu_mm_reduce_scatter_base supports tp_size >= 16.
        mmrs_fusion = tp_world_size <= 8

        # set for sequence parallelism, 1000 is the batch size concurrency threshold
        # for enabling the flashcomm_v1 or sequence_parallelism feature.
        # Currently, it is an empirical value. In normal scenarios, if the concurrency
        # exceeds this threshold, the performance benefits can be maximized.
        # Conversely, if the concurrency is below the threshold,
        # the performance may degrade due to the switching of communication methods.
        # main model and drafter model may have different architecture
        is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
        if is_context_moe_model:
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
            mmrs_fusion = False
        elif is_draft_model:
            # TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
            # Disable it to avoid more problems.
            sp_enabled = False
        else:
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000

        forward_context.mmrs_fusion = mmrs_fusion
        forward_context.num_tokens = num_tokens
        forward_context.sp_enabled = sp_enabled

        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
        forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None

        if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
            pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size

        # set this for rope forward_oot using
        forward_context.is_first_layer = True

        # set layer_idx to enable optimization features that depend on this information.
        # This is only applicable to models that contain these necessary attributes.
        forward_context.layer_idx = None
        if has_layer_idx(model_instance):
            forward_context.layer_idx = model_instance.model.start_layer

        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
        forward_context.model_instance = model_instance
        forward_context.is_draft_model = is_draft_model

        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens

        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
            if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
                padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
                forward_context.padded_length = padded_length
                forward_context.pad_size = pad_size
        else:
            max_tokens_across_dp = num_tokens
        forward_context.max_tokens_across_dp = max_tokens_across_dp

        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            # NOTE: token num which need to pad to when mc2
            forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size

            reserved_mc2_mask = get_mc2_mask()
            if reserved_mc2_mask is not None:
                mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
                mc2_mask[:num_actual_tokens] = True
                mc2_mask[num_actual_tokens:] = False
                forward_context.mc2_mask = mc2_mask

        try:
            yield
        finally:
            pass

_mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: torch.Tensor | None = None
_sin: torch.Tensor | None = None
_cos: torch.Tensor | None = None


def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
    global _mc2_tokens_capacity
    if _mc2_tokens_capacity is not None:
        return
    if vllm_config.compilation_config.cudagraph_capture_sizes:
        max_num_tokens = vllm_config.compilation_config.max_cudagraph_capture_size
    else:
        # NOTE: To save memory, we cap the max number of tokens to 512.
        max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    tp_size = vllm_config.parallel_config.tensor_parallel_size
    # Use integer arithmetic for ceiling division.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    _mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size


def get_mc2_tokens_capacity():
    return _mc2_tokens_capacity


def set_mc2_mask(vllm_config, device):
    global _reserved_mc2_mask
    if _reserved_mc2_mask is not None:
        return
    if is_moe_model(vllm_config):
        _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
    else:
        _reserved_mc2_mask = None


def get_mc2_mask():
    return _reserved_mc2_mask


def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model=False) -> MoECommType | None:
    """Select the MoE communication method according to parallel settings,
    device generation, token count, and quantization.

    1. Non-MoE models return `None`.
    2. Without expert parallel, fall back to all-gather.
    3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
       and the DP size is large enough; otherwise use all-gather.
    4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
       quantization with small EP size, no dynamic_eplb, and not in MTP
       mode; otherwise use MC2 within capacity or all-to-all.
    5. On 310P, always use all-gather.

    Args:
        num_tokens (int): The number of tokens in the current batch.
        vllm_config (VllmConfig): Runtime configuration for the model.
        is_draft_model (bool): Whether the model runs in MTP mode (disables fused MC2).

    Raises:
        ValueError: If the soc version is unsupported.

    Returns:
        MoECommType | None: The selected MoE communication method.
    """
    if not is_moe_model(vllm_config):
        return None

    mc2_tokens_capacity = get_mc2_tokens_capacity()
    soc_version = get_ascend_device_type()
    quant_type = getattr(
        vllm_config.model_config.hf_text_config,
        "moe_quantize",
        getattr(vllm_config.model_config.hf_text_config, "quantize", None),
    )

    if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
        moe_comm_type = MoECommType.ALLGATHER
    elif soc_version in {AscendDeviceType.A2}:
        if (
            num_tokens <= mc2_tokens_capacity
            and vllm_config.parallel_config.world_size_across_dp / vllm_config.parallel_config.pipeline_parallel_size
            >= 16
        ):
            moe_comm_type = MoECommType.MC2
        else:
            moe_comm_type = MoECommType.ALLGATHER
    elif soc_version in {AscendDeviceType.A3}:
        dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
        # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
        # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
        if num_tokens <= mc2_tokens_capacity:
            fused_decode_enable = fused_mc2_enable
            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
                fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
            moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
        else:
            fused_prefill_enable = fused_mc2_enable
            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
                fused_prefill_enable = False
            moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
    elif soc_version in {AscendDeviceType._310P}:
        moe_comm_type = MoECommType.ALLGATHER
    else:
        raise ValueError(f"Unsupported soc_version: {soc_version}")

    return moe_comm_type
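A minimal sketch of how a caller might wrap a forward pass with the context manager defined above; the `run_forward` helper and its arguments are placeholders rather than vllm-ascend internals, and the real model runner passes many more of the parameters shown in the signature.

```python
# Hypothetical usage sketch; `model`, `attn_metadata`, `vllm_config`, and the
# input tensors are placeholders for objects the model runner already owns.
from vllm_ascend.ascend_forward_context import set_ascend_forward_context


def run_forward(model, attn_metadata, vllm_config, input_ids, positions, num_tokens):
    with set_ascend_forward_context(
        attn_metadata=attn_metadata,
        vllm_config=vllm_config,
        num_tokens=num_tokens,
        model_instance=model,
    ):
        # Inside the context, MoE layers can read forward_context.moe_comm_type /
        # forward_context.moe_comm_method as chosen by select_moe_comm_method().
        return model(input_ids, positions)
```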