[Feature] implement set_additional_forward_context for model runner v2 (#5720)

### What this PR does / why we need it?
We implement set_additional_forward_context in the platform. It is necessary
in order to reuse the GPU code in model runner v2 by inheriting methods from
the GPU model runner v2. Please see model runner v2's plan #5208.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-01-15 09:18:28 +08:00
committed by GitHub
parent 4811ba62e0
commit 7078dff691

View File

@@ -15,11 +15,13 @@
# This file is a part of the vllm-ascend project.
#
import math
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
from uuid import uuid4
import torch
import vllm.envs as envs_vllm
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum
@@ -27,14 +29,14 @@ from vllm.platforms import Platform, PlatformEnum
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.utils import refresh_block_size
# isort: off
from vllm_ascend.utils import (
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD,
COMPILATION_PASS_KEY, AscendDeviceType, enable_sp, get_ascend_device_type,
is_vl_model, update_aclgraph_sizes, update_cudagraph_capture_sizes,
update_default_aclgraph_sizes, check_kv_extra_config)
ASCEND_QUANTIZATION_METHOD, COMPILATION_PASS_KEY,
COMPRESSED_TENSORS_METHOD, AscendDeviceType, check_kv_extra_config,
enable_sp, flashcomm2_enable, get_ascend_device_type, is_moe_model,
is_vl_model, refresh_block_size, update_aclgraph_sizes,
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -330,7 +332,8 @@ class NPUPlatform(Platform):
compilation_config.custom_ops = ["all"]
if ascend_config.recompute_scheduler_enable:
from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig
from vllm_ascend.core.recompute_scheduler import \
RecomputeSchedulerConfig
recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
vllm_config)
vllm_config.scheduler_config = recompute_scheduler_config
@@ -435,3 +438,134 @@ class NPUPlatform(Platform):
@classmethod
def support_static_graph_mode(cls) -> bool:
    """Whether this platform supports static graph capture mode.

    Always True for Ascend NPUs.
    """
    return True
@classmethod
def set_additional_forward_context(
    cls,
    attn_metadata: dict[str, Any],
    vllm_config: VllmConfig,
    dp_metadata,
    virtual_engine: int = 0,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode=None,
    batch_descriptor=None,
    ubatch_slices=None,
) -> dict[str, Any]:
    """Build the additional forward-context fields for Ascend NPUs.

    Args:
        attn_metadata (dict[str, Any]): attention metadata for all layers,
            keyed by layer name; each value is expected to expose a
            ``num_actual_tokens`` attribute (TODO confirm against callers).
        vllm_config (VllmConfig): configuration of vllm.
        dp_metadata: metadata for data parallelism (``DPMetadata``);
            lacks a type hint because of a circular import.
        virtual_engine (int, optional): index of virtual engine. Defaults to 0.
        num_tokens (int | None, optional): number of tokens; falls back to
            ``num_actual_tokens`` from the attention metadata when None.
        num_tokens_across_dp (torch.Tensor | None, optional): number of
            tokens across data parallelism. Defaults to None.
        cudagraph_runtime_mode: ``CUDAGraphMode`` of the cudagraph runtime;
            None stands in for ``CUDAGraphMode.NONE`` (no type hint or real
            default because of a circular import).
        batch_descriptor: ``BatchDescriptor`` of the batch. Defaults to None.
        ubatch_slices: ``UBatchSlices`` slice info for dual batch.
            Defaults to None; lacks a type hint because of a circular import.

    Returns:
        dict[str, Any]: extra fields to merge into vLLM's forward context
        (MoE comm method, SP/flashcomm flags, padding sizes, MC2 mask).
        Empty when model runner v2 is disabled.
    """
    # TODO(Ronald1995): model runner v1 still uses ascend_forward_context;
    # when v1's forward context is refactored this branch can be removed.
    # Compared to v1, v2's forward context lacks some fields, such as:
    # in_profile_run, is_first_layer, prefetch_mlp_gate_up_proj,
    # prefetch_mlp_gate_down_proj, prefetch_mlp_enabled, model_instance,
    # is_draft_model.
    # Checked first so the v1 path pays no import cost below.
    if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
        return {}

    # NOTE(Ronald1995): local imports avoid circular imports.
    from vllm.config import CUDAGraphMode
    from vllm.distributed import (get_dp_group,
                                  get_tensor_model_parallel_world_size)

    from vllm_ascend.ascend_forward_context import (get_mc2_mask,
                                                    select_moe_comm_method)
    from vllm_ascend.ops.fused_moe.moe_comm_method import \
        get_moe_comm_method

    # NOTE(Ronald1995): cudagraph_runtime_mode defaults to
    # CUDAGraphMode.NONE in vllm, but CUDAGraphMode can't be used in an
    # argument default here (circular import), so None maps to it now.
    if cudagraph_runtime_mode is None:
        cudagraph_runtime_mode = CUDAGraphMode.NONE

    num_actual_tokens = next(
        iter(attn_metadata.values())).num_actual_tokens
    if num_tokens is None:
        num_tokens = num_actual_tokens
    moe_comm_type = select_moe_comm_method(
        num_tokens,
        vllm_config,
        # is_draft_model will be removed later, so it is False temporarily.
        is_draft_model=False,
    )
    moe_comm_method = get_moe_comm_method(moe_comm_type)
    tp_world_size = get_tensor_model_parallel_world_size()
    # NOTE: This cannot be set using set_forward_context
    # due to multiple warmups before actual capturing.
    capturing = False
    # Set for sequence parallelism: 1000 is the empirical batch-size
    # concurrency threshold for enabling flashcomm_v1 / sequence
    # parallelism. Above the threshold the communication-method switch
    # pays off; below it performance may degrade.
    mmrs_fusion = True
    if is_moe_model(vllm_config):
        sp_enabled = enable_sp(vllm_config) and num_tokens is not None
        mmrs_fusion = False
    else:
        sp_enabled = (enable_sp(vllm_config) and num_tokens is not None
                      and num_tokens > 1000)
    # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
    flashcomm_v2_enabled = (flashcomm2_enable() and tp_world_size > 1
                            and num_tokens is not None)
    pad_size = 0
    # FIX: initialize padded_length and mc2_mask up front so the return
    # dict never raises NameError when the branches that compute them
    # (DP>1 with SP/FC2, or a reserved MC2 mask being available) are
    # skipped.
    padded_length = 0
    mc2_mask = None
    if sp_enabled or flashcomm_v2_enabled:
        pad_size = (tp_world_size -
                    (num_tokens % tp_world_size)) % tp_world_size
    dp_world_size = get_dp_group().world_size
    if dp_world_size > 1 and dp_metadata is not None:
        max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
        if sp_enabled or flashcomm_v2_enabled:
            # Round the DP-wide maximum up to a multiple of TP world size.
            padded_length = (max_tokens_across_dp + tp_world_size -
                             1) // tp_world_size * tp_world_size
            pad_size = padded_length - num_tokens
    else:
        max_tokens_across_dp = num_tokens
    # num_tokens is never None here (filled from num_actual_tokens above);
    # the guard is kept for safety.
    if num_tokens is not None:
        # NOTE: token count to pad to when using MC2.
        padded_num_tokens = math.ceil(
            max_tokens_across_dp / tp_world_size) * tp_world_size
        reserved_mc2_mask = get_mc2_mask()
        if reserved_mc2_mask is not None:
            mc2_mask = reserved_mc2_mask[:padded_num_tokens]
            mc2_mask[:num_actual_tokens] = True
            mc2_mask[num_actual_tokens:] = False
    return {
        "moe_comm_type": moe_comm_type,
        "moe_comm_method": moe_comm_method,
        "capturing": capturing,
        "mmrs_fusion": mmrs_fusion,
        "num_tokens": num_tokens,
        "sp_enabled": sp_enabled,
        "flashcomm_v2_enabled": flashcomm_v2_enabled,
        "pad_size": pad_size,
        "padded_length": padded_length,
        "max_tokens_across_dp": max_tokens_across_dp,
        "mc2_mask": mc2_mask,
    }