[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR adds aclgraph support for model runner v2; see RFC
#5208 for background. It contains these modifications:
- adapt to the latest commit of the vLLM main branch.
- provide a unified extra-forward-context interface shared by model
runner v1 and model runner v2 (a sketch of that interface follows this list).
- implement graph mode for the main model.
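
For orientation, here is a minimal sketch of what such a unified extra-context holder could look like. The field names are taken from the call sites touched in this diff; the actual definition lives in `vllm_ascend/ascend_forward_context.py` and is not shown in this excerpt, so treat everything below as an assumption rather than the real implementation.

```python
# Sketch only: a module-level extra-context object that both model runner v1
# and model runner v2 populate before the forward pass, so ops can read it
# directly instead of calling get_forward_context(). Field names come from
# the call sites in this diff; everything else is assumed.
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class ExtraForwardContext:
    # MoE communication strategy chosen for the current step.
    moe_comm_method: Optional[Any] = None
    moe_comm_type: Optional[Any] = None
    # Set during the profile/dummy run so MoE layers can force load balancing.
    in_profile_run: bool = False
    # Flash-comm / padding metadata shared by both runners.
    flash_comm_v1_enabled: bool = False
    mc2_mask: Optional[torch.Tensor] = None
    padded_num_tokens: int = 0
    max_tokens_across_dp: int = 0
    max_tokens_across_pcp: int = 0


# Module-level singleton read by the fused-MoE ops in the diff below.
_EXTRA_CTX = ExtraForwardContext()
```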

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Author: Ronald
Date: 2026-03-13 09:11:46 +08:00 (committed by GitHub)
Commit: c980e68d40 (parent: 1f71da80eb)
52 changed files with 840 additions and 309 deletions


@@ -37,7 +37,7 @@ if not vllm_version_is("0.16.0"):
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore
from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
@@ -148,7 +148,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
-moe_comm_method = get_forward_context().moe_comm_method
+moe_comm_method = _EXTRA_CTX.moe_comm_method
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@@ -401,12 +401,13 @@ class AscendFusedMoE(FusedMoE):
# When static kernels are enabled, the forward pass runs twice (compilation + capture),
# causing moe_layer_index to overflow. Wrap the index to prevent out-of-bounds errors.
if self.enable_npugraph_ex_static_kernel:
-forward_context.moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
+moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
+forward_context.moe_layer_index = moe_layer_index
# Load balancing for token distribution among experts in dummy_run
# TODO: The community only considers load balancing when DP > 1.
# This approach may overlook some extreme scenarios.
-enable_force_load_balance = forward_context.in_profile_run
+enable_force_load_balance = _EXTRA_CTX.in_profile_run
forward_context = get_forward_context()
if self.multistream_overlap_gate:
@@ -419,7 +420,7 @@ class AscendFusedMoE(FusedMoE):
assert fc3_context.shared_experts is not None
shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-moe_comm_type = forward_context.moe_comm_type
+moe_comm_type = _EXTRA_CTX.moe_comm_type
if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
and not shared_expert_dp_enabled()
@@ -442,16 +443,16 @@ class AscendFusedMoE(FusedMoE):
global_num_experts=self.global_num_experts,
)
-if isinstance(forward_context.moe_comm_method, AllGatherCommImpl):
+if isinstance(_EXTRA_CTX.moe_comm_method, AllGatherCommImpl):
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
-hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
+hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
-replace_allreduce=forward_context.flash_comm_v1_enabled,
+replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type,
)
@@ -509,7 +510,7 @@ class AscendFusedMoE(FusedMoE):
self.load_counter.add_(1)
else:
self.moe_load.add_(local_load)
-routed_out = forward_context.moe_comm_method.finalize(
+routed_out = _EXTRA_CTX.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results,
context_metadata=context_metadata,
@@ -670,8 +671,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# NOTE: This is exactly the opposite of
# `maybe_all_reduce_tensor_model_parallel`
-forward_context = get_forward_context()
-moe_comm_type = forward_context.moe_comm_type
+moe_comm_type = _EXTRA_CTX.moe_comm_type
if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
and not shared_expert_dp_enabled()

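As an aside on the `moe_layer_index` wrap in the `AscendFusedMoE` hunk above: with static kernels enabled the forward pass runs twice (once for compilation, once for graph capture), so a per-layer counter that keeps incrementing would run past the number of MoE layers. The snippet below is an illustration with hypothetical numbers, not code from this PR:

```python
# Illustration only (hypothetical 4-layer model): why the per-layer counter
# is wrapped with a modulo when the forward pass runs twice.
all_moe_layers = list(range(4))
moe_layer_index = 0

for _pass in ("compile", "capture"):
    for _ in all_moe_layers:
        # Without the modulo, the second pass would see indices 4..7 and
        # index out of bounds; with it, both passes see 0..3.
        moe_layer_index = moe_layer_index % len(all_moe_layers)
        current_layer = all_moe_layers[moe_layer_index]
        moe_layer_index += 1
```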

@@ -19,11 +19,10 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
-from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalize,
@@ -135,7 +134,7 @@ class MoECommMethod(ABC):
# Check constraints
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
-moe_comm_method = get_forward_context().moe_comm_method
+moe_comm_method = _EXTRA_CTX.moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
before_dispatch_evt = torch.npu.current_stream().record_event()

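
The "Missing communication context" assertion above implies that whichever runner drives the model must fill in the shared context before dispatch, in eager mode as well as during aclgraph capture. A hypothetical sketch of how that could be done (the function name and the exact fields are illustrative; the PR's actual plumbing lives in the runner code, which is not part of this excerpt):

```python
# Hypothetical helper (illustrative names): a runner-side context manager that
# fills the shared extra context before invoking the model and restores the
# previous values afterwards, so ops like MoECommMethod never see None.
from contextlib import contextmanager


@contextmanager
def with_extra_forward_context(ctx, *, moe_comm_method, moe_comm_type,
                               in_profile_run=False):
    saved = (ctx.moe_comm_method, ctx.moe_comm_type, ctx.in_profile_run)
    ctx.moe_comm_method = moe_comm_method
    ctx.moe_comm_type = moe_comm_type
    ctx.in_profile_run = in_profile_run
    try:
        yield ctx
    finally:
        ctx.moe_comm_method, ctx.moe_comm_type, ctx.in_profile_run = saved
```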

@@ -18,10 +18,9 @@
import torch
import torch_npu
from torch.nn.functional import pad
-from vllm.forward_context import get_forward_context
from vllm.triton_utils import HAS_TRITON
-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.device.mxfp_compat import (
ensure_mxfp8_moe_available,
@@ -147,7 +146,7 @@ def quant_apply_mlp(
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
-is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
+is_mc2 = _EXTRA_CTX.moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and w1_offset is None and is_mc2:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant:
# gmm1: gate_up_proj & act_fn: swiglu


@@ -26,10 +26,10 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
-from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
@@ -242,8 +242,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
-forward_context = get_forward_context()
-mc2_mask = forward_context.mc2_mask
+mc2_mask = _EXTRA_CTX.mc2_mask
if self.tp_size > 1:
# Also slice mc2_mask
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
@@ -252,7 +251,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
padded_hidden_states_shape = hidden_states.shape
if not self.replace_allreduce:
self.num_tokens, _ = hidden_states.shape
-target_pad_length = forward_context.padded_num_tokens
+target_pad_length = _EXTRA_CTX.padded_num_tokens
pad_size = target_pad_length - self.num_tokens
# Pad if necessary (unless shared expert DP is enabled)
@@ -367,8 +366,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
-forward_context = get_forward_context()
-max_tokens_across_dp = forward_context.max_tokens_across_dp
+max_tokens_across_dp = _EXTRA_CTX.max_tokens_across_dp
self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens
@@ -381,8 +379,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
-forward_context = get_forward_context()
-max_tokens_across_pcp = forward_context.max_tokens_across_pcp
+max_tokens_across_pcp = _EXTRA_CTX.max_tokens_across_pcp
self.num_tokens_pcp = hidden_states.shape[0]
pad_size = max_tokens_across_pcp - self.num_tokens_pcp
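
The last two hunks follow the same pad-then-all-gather pattern: each DP (or PCP) rank pads its token dimension up to the maximum token count across ranks, read from the shared context, so the collective sees identically shaped tensors. A simplified sketch of that pattern (shapes and surrounding API are illustrative; the real code uses `self.moe_config.dp_group` and `_EXTRA_CTX`):

```python
# Simplified illustration of the padding step; only the pad arithmetic mirrors
# the code above, everything else is assumed.
import torch
from torch.nn.functional import pad


def pad_tokens_for_all_gather(hidden_states: torch.Tensor,
                              max_tokens_across_dp: int) -> tuple[torch.Tensor, int]:
    num_tokens = hidden_states.shape[0]
    pad_size = max_tokens_across_dp - num_tokens
    if pad_size > 0:
        # Pad the token (first) dimension so every rank contributes a tensor
        # of identical shape to the all-gather; num_tokens is kept so the
        # padding can be stripped from this rank's slice afterwards.
        hidden_states = pad(hidden_states, (0, 0, 0, pad_size))
    return hidden_states, num_tokens
```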