[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to the newest commit of the vLLM main branch.
- provide a unified extra-forward-context interface shared by model
runner v1 and model runner v2.
- implement graph mode for main model. 

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-03-13 09:11:46 +08:00
committed by GitHub
parent 1f71da80eb
commit c980e68d40
52 changed files with 840 additions and 309 deletions

View File

@@ -13,7 +13,7 @@ from vllm.distributed import (
from vllm.forward_context import get_forward_context
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
from vllm_ascend.ops.triton.muls_add import muls_add_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -22,12 +22,12 @@ from vllm_ascend.utils import npu_stream_switch, prefetch_stream
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
try:
forward_context = get_forward_context()
get_forward_context()
except AssertionError:
return residual
if x.size(0) != residual.size(0):
pad_size = forward_context.pad_size
pad_size = _EXTRA_CTX.pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
tp_size = get_tensor_model_parallel_world_size()
@@ -43,12 +43,12 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
except AssertionError:
return x
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
if flash_comm_v1_enabled and label:
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
x = tensor_model_parallel_all_gather(x, 0)
pad_size = forward_context.pad_size
pad_size = _EXTRA_CTX.pad_size
if pad_size > 0:
x = x[:-pad_size]
else:
@@ -57,7 +57,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
dp_size = get_dp_group().world_size
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
x = x.view(dp_size, _EXTRA_CTX.padded_length, *x.shape[1:])
offset = 0
for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx]
@@ -79,7 +79,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
pad_size = forward_context.pad_size
pad_size = _EXTRA_CTX.pad_size
if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0)
@@ -87,7 +87,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
# padding
dp_size = get_dp_group().world_size
num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
padded_x = torch.empty((dp_size, forward_context.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype)
padded_x = torch.empty((dp_size, _EXTRA_CTX.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype)
offset = 0
for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx]
@@ -98,7 +98,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().flash_comm_v1_enabled and label:
if _EXTRA_CTX.flash_comm_v1_enabled and label:
return torch.empty(
(x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
)
@@ -107,7 +107,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().flash_comm_v1_enabled:
if _EXTRA_CTX.flash_comm_v1_enabled:
return torch.empty(
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
)
@@ -138,11 +138,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
moe_comm_type = _EXTRA_CTX.moe_comm_type
if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
or forward_context.flash_comm_v1_enabled
or _EXTRA_CTX.flash_comm_v1_enabled
):
return final_hidden_states
else:
@@ -163,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
forward_context = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
num_tokens = input_parallel.size(0)
if forward_context.flash_comm_v1_enabled:
if _EXTRA_CTX.flash_comm_v1_enabled:
num_tokens = num_tokens // self.tp_size
output = torch.empty(
size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype