Files
xc-llm-ascend/vllm_ascend/ascend_forward_context.py
zxr2333 5645ca8392 [BugFix]A2 MOE method&& layerwise MTP bugfix && Mamba gdn_metadata bugfix (#7364)
### What this PR does / why we need it?
Some bug fixes, mainly including:
1. For A2, the number of experts on each card cannot be greater than
16 when using MC2. This PR fixes the A2 MoE communication-method
selection, which chose an incorrect communication method when the
number of model experts exceeds 256. For example, when an A2 16-card
deployment serves as the decode (D) node in PD disaggregation with
Qwen3.5 series models, the MC2 method would be incorrectly
chosen.
2. Fixed the issue where the layerwise connector sends the KV cache of
the MTP layer multiple times when `num_spec_tokens` > 1. Now the
KV cache is sent only the first time the MTP layer runs forward.
3. Fix the accuracy issue of Qwen3.5 when using MTP with PD
disaggregation. The cause is that `num_decode_draft_tokens` does not
account for the fact that `spec_tokens` do not exist during the first
inference under PD disaggregation (`spec_tokens` are generated during the
first inference), whereas `spec_tokens_padding` is added by
`recomputed_scheduler`. As a result, `gdn_metadata` incorrectly
treats the step as a prefill of length 2.
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: zxr2333 <64738772+nwpu-zxr@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-17 23:03:45 +08:00

337 lines
13 KiB
Python

import math
from contextlib import contextmanager
from enum import Enum
from typing import Any
import torch
import vllm.envs as envs_vllm
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import (
AscendDeviceType,
enable_sp,
flashcomm2_enable,
get_ascend_device_type,
has_layer_idx,
is_drafter_moe_model,
is_moe_model,
speculative_enable_dispatch_gmm_combine_decode,
)
class MoECommType(Enum):
    """Communication strategy used by the fused-MoE layers.

    ``select_moe_comm_method`` picks one member per forward pass; the choice is
    stored on the forward context and handed to ``get_moe_comm_method`` to
    obtain the concrete implementation.
    """

    # Member order and integer values are kept stable on purpose.
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    # Only ever selected on A3 devices, gated by VLLM_ASCEND_ENABLE_FUSED_MC2.
    FUSED_MC2 = 3
@contextmanager
def set_ascend_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int | None = 0,
    num_tokens_across_dp: torch.Tensor | None = None,
    in_profile_run: bool = False,
    num_actual_tokens: int | None = None,
    aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    model_instance: torch.nn.Module | None = None,
    is_draft_model: bool = False,
    skip_compiled: bool = False,
    max_tokens_across_pcp: int = 0,
    draft_attn_metadatas: Any = None,
):
    """Install vLLM's forward context and attach Ascend-specific fields to it.

    Enters vllm's ``set_forward_context`` with the core arguments, then sets
    extra attributes on the resulting context:
    - the MoE communication type and method (``select_moe_comm_method`` and
      ``get_moe_comm_method``);
    - flash-comm v1/v2 enable flags, ``pad_size``/``padded_length`` and
      ``mmrs_fusion``;
    - layer/prefetch bookkeeping (``is_first_layer``, ``layer_idx``,
      ``prefetch_mlp_*``, ``model_instance``, ``is_draft_model``);
    - DP/PCP max-token counts and, when a reserved MC2 mask exists, the MC2
      token mask.

    Args:
        attn_metadata: Forwarded to ``set_forward_context``. If ``num_tokens``
            is None, its ``num_actual_tokens`` is used as the token count for
            the DP/MC2 section below.
        vllm_config: Runtime configuration.
        virtual_engine: Forwarded to ``set_forward_context``.
        num_tokens: Token count of the current batch. Used for MoE
            communication selection, padding and the MC2 mask. The body also
            tolerates ``None``.
        num_tokens_across_dp: Forwarded to ``set_forward_context``.
        in_profile_run: Stored on the context as ``in_profile_run``.
        num_actual_tokens: Number of real tokens. MC2 mask positions
            ``[:num_actual_tokens]`` are True and the rest False. Defaults to
            ``num_tokens``.
        aclgraph_runtime_mode: Forwarded as ``cudagraph_runtime_mode``.
        batch_descriptor: Forwarded to ``set_forward_context``.
        model_instance: Stored on the context. When ``has_layer_idx`` is true,
            its ``model.start_layer`` seeds ``layer_idx``.
        is_draft_model: Whether this forward pass is for the drafter model.
            Selects drafter-vs-main MoE detection, and disables flash-comm v1
            for dense drafters.
        skip_compiled: Forwarded to ``set_forward_context``.
        max_tokens_across_pcp: Stored on the context as-is.
        draft_attn_metadatas: Stored on the context as-is.

    Yields:
        None. The attributes remain on the forward context while the caller's
        ``with`` body runs.
    """
    forward_context_kwargs = {
        "attn_metadata": attn_metadata,
        "vllm_config": vllm_config,
        "virtual_engine": virtual_engine,
        "num_tokens": num_tokens,
        "num_tokens_across_dp": num_tokens_across_dp,
        "cudagraph_runtime_mode": aclgraph_runtime_mode,
        "batch_descriptor": batch_descriptor,
        "skip_compiled": skip_compiled,
    }

    with set_forward_context(**forward_context_kwargs):
        forward_context = get_forward_context()
        forward_context.draft_attn_metadatas = draft_attn_metadatas

        # Imported locally rather than at module top level
        # (presumably to avoid an import cycle -- TODO confirm).
        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method

        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
        forward_context.moe_comm_type = moe_comm_type
        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

        tp_world_size = get_tensor_model_parallel_world_size()

        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False

        # TODO: remove it when torch_npu.npu_mm_reduce_scatter_base supports tp_size >= 16.
        mmrs_fusion = tp_world_size <= 8

        # set for sequence parallelism, 1000 is the batch size concurrency threshold
        # for enabling the flashcomm_v1 or sequence_parallelism feature
        # (the threshold is only applied to dense main models below).
        # Currently, it is an empirical value. In normal scenarios, if the concurrency
        # exceeds this threshold, the performance benefits can be maximized.
        # Conversely, if the concurrency is below the threshold,
        # the performance may degrade due to the switching of communication methods.

        # main model and drafter model may have different architecture
        is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
        if is_context_moe_model:
            # MoE models follow enable_sp() regardless of batch size, and never use mmrs fusion.
            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
            mmrs_fusion = False
        elif is_draft_model:
            # TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
            # Disable it to avoid more problems.
            flash_comm_v1_enabled = False
        else:
            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
        forward_context.mmrs_fusion = mmrs_fusion
        forward_context.num_tokens = num_tokens
        forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
        forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None

        # Local padding: the number of extra tokens that rounds num_tokens up to a
        # multiple of tp_world_size. May be overridden by the DP-wide value below.
        forward_context.pad_size = 0
        if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
            pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size

        # set this for rope forward_oot using
        forward_context.is_first_layer = True

        # set layer_idx to enable optimization features that depend on this information.
        # This is only applicable to models that contain these necessary attributes.
        forward_context.layer_idx = None
        if has_layer_idx(model_instance):
            forward_context.layer_idx = model_instance.model.start_layer

        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
        forward_context.model_instance = model_instance
        forward_context.is_draft_model = is_draft_model

        # NOTE(review): this fallback runs after num_tokens was already passed to
        # select_moe_comm_method and stored as forward_context.num_tokens, so those
        # earlier uses still see the un-substituted value.
        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens

        dp_world_size = get_dp_group().world_size
        # dp_metadata is read from the context populated by vllm's
        # set_forward_context (presumably only present when DP info is available).
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
            if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
                # Pad every rank to the DP-wide maximum, rounded up to a multiple of
                # tp_world_size; this replaces the local pad_size computed above.
                padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
                forward_context.padded_length = padded_length
                forward_context.pad_size = pad_size
        else:
            max_tokens_across_dp = num_tokens

        forward_context.max_tokens_across_dp = max_tokens_across_dp
        forward_context.max_tokens_across_pcp = max_tokens_across_pcp

        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            # NOTE: token num which need to pad to when mc2
            forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size

            # get_mc2_mask() returns None for non-MoE models (see set_mc2_mask).
            reserved_mc2_mask = get_mc2_mask()
            if reserved_mc2_mask is not None:
                # Basic slicing of the preallocated tensor yields a view, so the
                # writes below update the reserved buffer in place.
                mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
                mc2_mask[:num_actual_tokens] = True
                mc2_mask[num_actual_tokens:] = False
                forward_context.mc2_mask = mc2_mask

        # The empty `finally` adds no cleanup of its own; leaving the enclosing
        # `with` block after the caller finishes tears down vllm's forward context.
        try:
            yield
        finally:
            pass
_mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: torch.Tensor | None = None
def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
    """Compute the MC2 token capacity and store it in the module-level cache.

    The first call wins: once a capacity is stored, later calls return
    immediately without recomputing.

    The token budget is ``compilation_config.max_cudagraph_capture_size`` when
    graph capture sizes are configured. Otherwise it is
    ``max_num_reqs * uniform_decode_query_len``, capped at 512. The budget is
    then rounded up to the nearest multiple of the tensor-parallel size.
    """
    global _mc2_tokens_capacity
    if _mc2_tokens_capacity is not None:
        return

    compilation_config = vllm_config.compilation_config
    if not compilation_config.cudagraph_capture_sizes:
        # NOTE: To save memory, we cap the max number of tokens to 512.
        token_budget = min(max_num_reqs * uniform_decode_query_len, 512)
    else:
        token_budget = compilation_config.max_cudagraph_capture_size

    tp = vllm_config.parallel_config.tensor_parallel_size
    # Integer ceiling division via negated floor division, then scale back up.
    _mc2_tokens_capacity = -(-token_budget // tp) * tp
def get_mc2_tokens_capacity():
    """Return the cached MC2 token capacity, or None before set_mc2_tokens_capacity() runs."""
    capacity = _mc2_tokens_capacity
    return capacity
def set_mc2_mask(vllm_config, device):
    """Allocate the reserved MC2 mask buffer if none exists yet.

    For MoE models, the buffer is an all-False ``torch.bool`` tensor of length
    ``get_mc2_tokens_capacity()`` on ``device``. For other models the reserved
    mask stays ``None``. An already-allocated buffer is never replaced.
    """
    global _reserved_mc2_mask
    if _reserved_mc2_mask is None:
        _reserved_mc2_mask = (
            torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
            if is_moe_model(vllm_config)
            else None
        )
def get_mc2_mask():
    """Return the reserved MC2 mask buffer, or None when it was never allocated."""
    mask = _reserved_mc2_mask
    return mask
def select_moe_comm_method(
    num_tokens: int, vllm_config: VllmConfig, is_draft_model: bool = False
) -> MoECommType | None:
    """Select the MoE communication method according to parallel settings,
    device generation, token count, and quantization.

    1. Non-MoE models return `None`.
    2. Without expert parallel (or with an EP group of size 1), fall back to
       all-gather.
    3. On A2 with expert parallel, pick MC2 when all of the following hold;
       otherwise use all-gather:
       - the EP size (world size across DP divided by the pipeline-parallel
         size) is at least 16;
       - the busiest device holds at most 24 experts;
       - the tokens fit the MC2 capacity.
    4. On A3 with expert parallel:
       - Within the MC2 capacity, use fused MC2 when allowed, otherwise MC2.
       - Above the capacity, use fused MC2 when allowed, otherwise all-to-all.
       Fused MC2 is allowed when ``VLLM_ASCEND_ENABLE_FUSED_MC2`` is truthy,
       with extra requirements for its two modes:
       - Mode 1 also requires an EP group of at most 32 ranks and a non-draft
         model.
       - Mode 2 never fuses above the capacity. Within it, mode 2 also
         requires ``speculative_enable_dispatch_gmm_combine_decode`` and
         ``w8a8_dynamic`` MoE quantization.
    5. On 310P, always use all-gather.
    6. On A5, use MC2 when the tokens fit the MC2 capacity and the world size
       across DP is greater than 1; otherwise use all-to-all.

    Args:
        num_tokens (int): The number of tokens in the current batch.
        vllm_config (VllmConfig): Runtime configuration for the model.
        is_draft_model (bool): Whether this is the draft (e.g. MTP) model.
            On A3 it disables the mode-1 (dispatch_ffn_combine) fused path.

    Raises:
        ValueError: If the soc version is unsupported.

    Returns:
        MoECommType | None: The selected MoE communication method.
    """
    if not is_moe_model(vllm_config):
        return None

    mc2_tokens_capacity = get_mc2_tokens_capacity()
    soc_version = get_ascend_device_type()
    # Prefer the MoE-specific quantization tag, falling back to the global one.
    quant_type = getattr(
        vllm_config.model_config.hf_text_config,
        "moe_quantize",
        getattr(vllm_config.model_config.hf_text_config, "quantize", None),
    )

    if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
        moe_comm_type = MoECommType.ALLGATHER

    elif soc_version in {AscendDeviceType.A2}:
        num_experts = vllm_config.model_config.get_num_experts()
        ep_world_size = (
            vllm_config.parallel_config.world_size_across_dp // vllm_config.parallel_config.pipeline_parallel_size
        )
        # Ceiling division: if the experts do not split evenly, the busiest
        # device holds ceil(num_experts / ep_world_size) experts, and that
        # maximum is what must respect the per-device MC2 limit. Floor division
        # under-counts (e.g. 390 experts on 16 devices -> 24 instead of 25).
        max_experts_per_device = (num_experts + ep_world_size - 1) // ep_world_size
        # NOTE(review): 24 is presumably the A2 MC2 per-device expert limit;
        # confirm against the CANN dispatch/combine operator constraints.
        if max_experts_per_device <= 24 and ep_world_size >= 16 and num_tokens <= mc2_tokens_capacity:
            moe_comm_type = MoECommType.MC2
        else:
            moe_comm_type = MoECommType.ALLGATHER

    elif soc_version in {AscendDeviceType.A3}:
        # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
        # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
        # A falsy value disables fused MC2; modes 1 and 2 add extra requirements.
        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2
        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
        if num_tokens <= mc2_tokens_capacity:
            fused_decode_enable = fused_mc2_enable
            if fused_mc2_enable == 1:
                fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif fused_mc2_enable == 2:
                fused_decode_enable = (
                    fused_mc2_enable
                    and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
                    and quant_type == "w8a8_dynamic"
                )
            moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
        else:
            fused_prefill_enable = fused_mc2_enable
            if fused_mc2_enable == 1:
                fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif fused_mc2_enable == 2:
                # Mode 2 (dispatch_gmm_combine_decode) is decode-only.
                fused_prefill_enable = False
            moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL

    elif soc_version in {AscendDeviceType._310P}:
        moe_comm_type = MoECommType.ALLGATHER

    elif soc_version in {AscendDeviceType.A5}:
        if num_tokens <= mc2_tokens_capacity and vllm_config.parallel_config.world_size_across_dp > 1:
            moe_comm_type = MoECommType.MC2
        else:
            moe_comm_type = MoECommType.ALLTOALL

    else:
        raise ValueError(f"Unsupported soc_version: {soc_version}")

    return moe_comm_type
class _ExtraForwardContextProxy:
    """Unified forward-context access for v1/v2 model runners.

    Reads and writes of the names listed in ``extra_attrs`` are routed to the
    active vLLM forward context (``get_forward_context()``):
    - When ``envs_vllm.VLLM_USE_V2_MODEL_RUNNER`` is set, they live in
      ``ctx.additional_kwargs``.
    - Otherwise they are plain attributes on ``ctx``.

    Any other name raises ``AttributeError``. A proxied name that has not been
    set also raises ``AttributeError`` on both paths, so ``hasattr`` and
    ``getattr(proxy, name, default)`` behave the same for either runner.
    """

    # Names that may be read/written through this proxy.
    extra_attrs = (
        "capturing",
        "moe_comm_type",
        "moe_comm_method",
        "mmrs_fusion",
        "num_tokens",
        "flash_comm_v1_enabled",
        "flashcomm_v2_enabled",
        "pad_size",
        "padded_length",
        "num_tokens_across_dp",
        "mc2_mask",
        "is_draft_model",
        "prefetch_mlp_gate_up_proj",
        "prefetch_mlp_down_proj",
        "model_instance",
        "layer_idx",
        "max_tokens_across_dp",
        "max_tokens_across_pcp",
        "num_accept_tokens",
        "in_profile_run",
        "padded_num_tokens",
    )

    def check_extra_attr(self, name: str):
        """Raise AttributeError unless ``name`` is one of ``extra_attrs``."""
        if name not in self.extra_attrs:
            raise AttributeError(
                f"{name} is not extra forward context attribute, "
                "please get/set it from vllm's _forward_context directly."
            )

    @staticmethod
    def _ctx():
        """Return the active vLLM forward context."""
        return get_forward_context()

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails, i.e. for proxied names.
        self.check_extra_attr(name)
        ctx = self._ctx()
        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
            try:
                return ctx.additional_kwargs[name]
            except KeyError as err:
                # __getattr__ must raise AttributeError (as the V1 getattr path
                # does) so hasattr()/getattr(..., default) keep working.
                raise AttributeError(f"{name} is not set on the current forward context") from err
        return getattr(ctx, name)

    def __setattr__(self, name: str, value: Any) -> None:
        self.check_extra_attr(name)
        ctx = self._ctx()
        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
            ctx.additional_kwargs[name] = value
        else:
            setattr(ctx, name, value)
# Shared module-level proxy instance for the Ascend-specific forward-context fields.
# usage: from vllm_ascend.ascend_forward_context import _EXTRA_CTX
_EXTRA_CTX = _ExtraForwardContextProxy()