[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to newest commit of vllm main branch.
- provide a unified interface for the extra forward context, shared by both
model runner v1 and model runner v2.
- implement graph mode for the main model.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-03-13 09:11:46 +08:00
committed by GitHub
parent 1f71da80eb
commit c980e68d40
52 changed files with 840 additions and 309 deletions

View File

@@ -4,6 +4,7 @@ from enum import Enum
from typing import Any
import torch
import vllm.envs as envs_vllm
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
@@ -270,3 +271,61 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
return moe_comm_type
class _ExtraForwardContextProxy:
"""Unified forward-context access for v1/v2 model runners."""
extra_attrs = (
"capturing",
"moe_comm_type",
"moe_comm_method",
"mmrs_fusion",
"num_tokens",
"flash_comm_v1_enabled",
"flashcomm_v2_enabled",
"pad_size",
"padded_length",
"num_tokens_across_dp",
"mc2_mask",
"is_draft_model",
"prefetch_mlp_gate_up_proj",
"prefetch_mlp_down_proj",
"model_instance",
"layer_idx",
"max_tokens_across_dp",
"max_tokens_across_pcp",
"num_accept_tokens",
"in_profile_run",
"padded_num_tokens",
)
def check_extra_attr(self, name: str):
if name not in self.extra_attrs:
raise AttributeError(
f"{name} is not extra forward context attribute, "
"please get/set it from vllm's _forward_context directly."
)
@staticmethod
def _ctx():
return get_forward_context()
def __getattr__(self, name: str) -> Any:
self.check_extra_attr(name)
ctx = self._ctx()
if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
return ctx.additional_kwargs[name]
return getattr(ctx, name)
def __setattr__(self, name: str, value: Any) -> None:
self.check_extra_attr(name)
ctx = self._ctx()
if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
ctx.additional_kwargs[name] = value
else:
setattr(ctx, name, value)
# usage: from vllm_ascend.ascend_forward_context import _EXTRA_CTX
_EXTRA_CTX = _ExtraForwardContextProxy()