# xc-llm-ascend/vllm_ascend/ascend_forward_context.py
import math
from contextlib import contextmanager
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional

import torch
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (get_dp_group, get_ep_group,
                              get_tensor_model_parallel_world_size)
from vllm.forward_context import (BatchDescriptor, get_forward_context,
                                  set_forward_context)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx,
                               is_moe_model)

if TYPE_CHECKING:
    from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
else:
    WeightPrefetchMethod = None


class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2
    AllGatherEP = 3
    NaiveMulticast = 4
    All2AllSeq = 5


def get_fused_moe_state(ep_size: int, with_prefill: bool,
                        is_deepseek_v3_r1: bool):
    # The fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by
    # AllGather EP only supports DeepSeek V3/R1.
    if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
            and is_deepseek_v3_r1):
        return FusedMoEState.AllGatherEP
    elif ep_size == 1:
        if with_prefill:
            return FusedMoEState.NaiveMulticast
        else:
            return FusedMoEState.AllGather
    # NOTE: MC2 needs ep_size >= 16, and All2All cannot be used in the torchair graph.
    elif ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    else:
        return FusedMoEState.MC2
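
# Illustrative mapping for reference (a sketch assuming
# VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP is not set; example values only, not
# exhaustive):
#   get_fused_moe_state(ep_size=1,  with_prefill=True,  is_deepseek_v3_r1=False) -> NaiveMulticast
#   get_fused_moe_state(ep_size=1,  with_prefill=False, is_deepseek_v3_r1=False) -> AllGather
#   get_fused_moe_state(ep_size=8,  with_prefill=False, is_deepseek_v3_r1=False) -> All2All
#   get_fused_moe_state(ep_size=16, with_prefill=False, is_deepseek_v3_r1=True)  -> MC2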


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    FUSED_ALLTOALL = 3


@contextmanager
def set_ascend_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        with_prefill: bool = True,
        in_profile_run: bool = False,
        reserved_mc2_mask: Optional[torch.Tensor] = None,
        moe_comm_type: Optional[MoECommType] = None,
        num_actual_tokens: Optional[int] = None,
        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        prefetch_stream: torch.npu.Stream = None,
        model_instance: torch.nn.Module = None,
        weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
        is_mtp_model: bool = False):
    """A context manager that stores the current forward context,
    which can be attention metadata, etc.
    We add some additional params to the forward_context.
    """
    with set_forward_context(
            attn_metadata,
            vllm_config,
            virtual_engine=virtual_engine,
            num_tokens=num_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=aclgraph_runtime_mode,
            batch_descriptor=batch_descriptor,
    ):
        forward_context = get_forward_context()

        from vllm_ascend.ops.fused_moe.moe_comm_method import \
            get_moe_comm_method
        forward_context.moe_comm_type = moe_comm_type
        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

        forward_context.with_prefill = with_prefill
        tp_world_size = get_tensor_model_parallel_world_size()
        ep_size = (get_ep_group().world_size if
                   vllm_config.parallel_config.enable_expert_parallel else 1)

        # fused_moe_state is only used by torchair; it will be removed together
        # with torchair.
        is_deepseek_v3_r1 = hasattr(
            vllm_config.model_config.hf_config, 'n_routed_experts'
        ) and vllm_config.model_config.hf_config.n_routed_experts == 256
        fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
                                              is_deepseek_v3_r1)
        forward_context.fused_moe_state = fused_moe_state
        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing.
        forward_context.capturing = False

        # Set for sequence parallelism. 1000 is the batch-size (concurrency)
        # threshold for enabling the flashcomm_v1 or sequence-parallelism
        # feature. It is currently an empirical value: above the threshold the
        # performance benefit is maximized, while below it performance may
        # degrade because of the switch in communication methods.
        mmrs_fusion = True
        if is_moe_model(vllm_config):
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
            mmrs_fusion = False
        else:
            sp_enabled = enable_sp(vllm_config) and \
                num_tokens is not None and num_tokens > 1000
        forward_context.mmrs_fusion = mmrs_fusion
        forward_context.num_tokens = num_tokens
        forward_context.sp_enabled = sp_enabled

        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2.
        forward_context.flashcomm_v2_enabled = flashcomm2_enable(
        ) and tp_world_size > 1 and num_tokens is not None

        if (forward_context.sp_enabled
                or forward_context.flashcomm_v2_enabled):
            pad_size = (tp_world_size -
                        (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size

        # Set this for rope forward_oot use.
        forward_context.is_first_layer = True

        # Set layer_idx to enable optimization features that depend on this
        # information. This is only applicable to models that contain the
        # necessary attributes.
        forward_context.layer_idx = None
        if has_layer_idx(model_instance):
            forward_context.layer_idx = model_instance.model.start_layer

        # TODO(rjg-lyh): refactor the MLP weight prefetch method.
        # Set for MLP weight prefetch.
        prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
            envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
            forward_context.layer_idx is not None and \
            num_tokens is not None and num_tokens < 500
        if prefetch_mlp_enabled:
            forward_context.prefetch_stream = prefetch_stream
            forward_context.model_instance = model_instance
            forward_context.prefetch_mlp_gate_up_proj = False
            forward_context.prefetch_mlp_down_proj = False
        forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled

        forward_context.model_instance = model_instance
        forward_context.weight_prefetch_method = weight_prefetch_method
        forward_context.is_mtp_model = is_mtp_model

        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens

        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = \
                forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
            if (forward_context.sp_enabled
                    or forward_context.flashcomm_v2_enabled):
                padded_length = (max_tokens_across_dp + tp_world_size -
                                 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
                forward_context.padded_length = padded_length
                forward_context.pad_size = pad_size
        else:
            max_tokens_across_dp = num_tokens
        forward_context.max_tokens_across_dp = max_tokens_across_dp

        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            # NOTE: the number of tokens to pad up to for MC2.
            forward_context.padded_num_tokens = math.ceil(
                max_tokens_across_dp / tp_world_size) * tp_world_size
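            # Worked example (illustrative numbers only): with tp_world_size = 4
            # and max_tokens_across_dp = 10, padded_num_tokens = ceil(10 / 4) * 4
            # = 12, so the MC2 mask below covers 12 slots, of which the first
            # num_actual_tokens are valid.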
            if reserved_mc2_mask is not None:
                mc2_mask = reserved_mc2_mask[:forward_context.
                                             padded_num_tokens]
                mc2_mask[:num_actual_tokens] = True
                mc2_mask[num_actual_tokens:] = False
                forward_context.mc2_mask = mc2_mask

        try:
            yield
        finally:
            pass
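

# Illustrative usage sketch (not part of the runtime module): a model-runner-style
# caller would typically wrap each forward pass with this context manager so that
# the Ascend-specific fields set above (moe_comm_method, sp_enabled, mc2_mask, ...)
# are visible to layers via vllm.forward_context.get_forward_context(). The
# `runner`, `model`, and tensor names below are hypothetical placeholders.
#
#     with set_ascend_forward_context(
#             attn_metadata,
#             runner.vllm_config,
#             num_tokens=num_input_tokens,
#             with_prefill=with_prefill,
#             moe_comm_type=MoECommType.MC2,
#             reserved_mc2_mask=runner.reserved_mc2_mask,
#             aclgraph_runtime_mode=CUDAGraphMode.NONE,
#             batch_descriptor=batch_descriptor):
#         hidden_states = model(input_ids=input_ids, positions=positions)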