[Feature] implement set_additional_forward_context for model runner v2 (#5720)
### What this PR does / why we need it?
We implement `set_additional_forward_context` on the platform. It is necessary for reusing the GPU code in model runner v2: the Ascend model runner v2 inherits this method from the GPU model runner v2. Please see the model runner v2 plan in #5208.
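At a high level, model runner v2 consumes this hook through the common runner code; a minimal sketch, assuming the runner merges the returned dict into its forward context (the runner-side names below are illustrative, not the actual vLLM code):

```python
# Hypothetical runner-side use of the new platform hook; only
# set_additional_forward_context itself is introduced by this PR.
from vllm.platforms import current_platform

def attach_platform_extras(forward_context, attn_metadata, vllm_config,
                           dp_metadata, num_tokens):
    # The platform returns a plain dict of backend-specific fields; an
    # empty dict means there is nothing extra to attach on this backend.
    extras = current_platform.set_additional_forward_context(
        attn_metadata, vllm_config, dp_metadata, num_tokens=num_tokens)
    for name, value in extras.items():
        setattr(forward_context, name, value)
```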
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
```diff
@@ -15,11 +15,13 @@
 # This file is a part of the vllm-ascend project.
 #

 import math
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from uuid import uuid4

 import torch
 import vllm.envs as envs_vllm
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
```
```diff
@@ -27,14 +29,14 @@ from vllm.platforms import Platform, PlatformEnum

 os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

 from vllm_ascend.ascend_config import init_ascend_config
-from vllm_ascend.utils import refresh_block_size

 # isort: off
 from vllm_ascend.utils import (
-    ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD,
-    COMPILATION_PASS_KEY, AscendDeviceType, enable_sp, get_ascend_device_type,
-    is_vl_model, update_aclgraph_sizes, update_cudagraph_capture_sizes,
-    update_default_aclgraph_sizes, check_kv_extra_config)
+    ASCEND_QUANTIZATION_METHOD, COMPILATION_PASS_KEY,
+    COMPRESSED_TENSORS_METHOD, AscendDeviceType, check_kv_extra_config,
+    enable_sp, flashcomm2_enable, get_ascend_device_type, is_moe_model,
+    is_vl_model, refresh_block_size, update_aclgraph_sizes,
+    update_cudagraph_capture_sizes, update_default_aclgraph_sizes)

 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
```
```diff
@@ -330,7 +332,8 @@ class NPUPlatform(Platform):
             compilation_config.custom_ops = ["all"]

         if ascend_config.recompute_scheduler_enable:
-            from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig
+            from vllm_ascend.core.recompute_scheduler import \
+                RecomputeSchedulerConfig
             recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
                 vllm_config)
             vllm_config.scheduler_config = recompute_scheduler_config
```
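For context, the recompute scheduler path (touched here only by the import rewrap) replaces `vllm_config.scheduler_config` wholesale. A schematic of that pattern, assuming `RecomputeSchedulerConfig` subclasses vLLM's `SchedulerConfig`; the field handling below is illustrative, not the actual vllm-ascend implementation:

```python
# Schematic config-swap pattern; the real class lives in
# vllm_ascend.core.recompute_scheduler.
from dataclasses import dataclass, fields
from vllm.config import SchedulerConfig

@dataclass
class RecomputeSchedulerConfig(SchedulerConfig):
    """Drop-in scheduler config used when recompute scheduling is enabled."""

    @classmethod
    def initialize_from_config(cls, vllm_config):
        # Rebuild from the currently-configured scheduler config so the
        # replacement keeps every user-provided scheduling option.
        base = vllm_config.scheduler_config
        kwargs = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
        return cls(**kwargs)
```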
```diff
@@ -435,3 +438,134 @@ class NPUPlatform(Platform):
     @classmethod
     def support_static_graph_mode(cls) -> bool:
         return True
+
+    @classmethod
+    def set_additional_forward_context(
+        cls,
+        attn_metadata: dict[str, Any],
+        vllm_config: VllmConfig,
+        dp_metadata,
+        virtual_engine: int = 0,
+        num_tokens: int | None = None,
+        num_tokens_across_dp: torch.Tensor | None = None,
+        cudagraph_runtime_mode=None,
+        batch_descriptor=None,
+        ubatch_slices=None,
+    ) -> dict[str, Any]:
+        """Set additional forward context for Ascend NPUs.
+
+        Args:
+            attn_metadata (dict[str, Any]): attention metadata for all layers.
+            vllm_config (VllmConfig): vLLM configuration.
+            dp_metadata (DPMetadata): metadata for data parallelism.
+                No type hint because of a circular import.
+            virtual_engine (int, optional): index of the virtual engine.
+                Defaults to 0.
+            num_tokens (int | None, optional): number of tokens. Defaults to
+                None.
+            num_tokens_across_dp (torch.Tensor | None, optional): number of
+                tokens across data-parallel ranks. Defaults to None.
+            cudagraph_runtime_mode (CUDAGraphMode, optional): cudagraph runtime
+                mode. Defaults to None. No type hint because of a circular
+                import.
+            batch_descriptor (BatchDescriptor, optional): descriptor of the
+                batch. Defaults to None.
+            ubatch_slices (UBatchSlices, optional): slice info for dual
+                batches. Defaults to None. No type hint because of a circular
+                import.
+
+        Returns:
+            dict[str, Any]: additional fields to attach to the forward context.
+        """
+        # NOTE(Ronald1995): avoid circular import.
+        from vllm_ascend.ascend_forward_context import (get_mc2_mask,
+                                                        select_moe_comm_method)
+        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
+        from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
+        # NOTE(Ronald1995): avoid circular import. cudagraph_runtime_mode is
+        # CUDAGraphMode.NONE in vllm, but we can't use CUDAGraphMode.NONE as
+        # the argument's default value, so we default it to None and resolve
+        # it to CUDAGraphMode.NONE here.
+        from vllm.config import CUDAGraphMode
+        if cudagraph_runtime_mode is None:
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
+        # TODO(Ronald1995): model runner v1 still uses ascend_forward_context;
+        # when v1's forward context is refactored, we can remove this branch.
+        # Currently, model runner v2 uses the new forward context. Compared to
+        # v1, v2's forward context lacks some fields, such as in_profile_run,
+        # is_first_layer, prefetch_mlp_gate_up_proj,
+        # prefetch_mlp_gate_down_proj, prefetch_mlp_enabled, model_instance
+        # and is_draft_model.
+        if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
+            return {}
+
+        num_actual_tokens = list(attn_metadata.values())[0].num_actual_tokens
+        if num_tokens is None:
+            num_tokens = num_actual_tokens
+
+        moe_comm_type = select_moe_comm_method(
+            num_tokens,
+            vllm_config,
+            # is_draft_model will be removed later, so we set it to False
+            # temporarily.
+            is_draft_model=False,
+        )
+        moe_comm_method = get_moe_comm_method(moe_comm_type)
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+
+        # NOTE: This cannot be set using set_forward_context
+        # due to multiple warmups before actual capturing.
+        capturing = False
+
+        # Set for sequence parallelism. 1000 is the batch-size concurrency
+        # threshold for enabling the flashcomm_v1 or sequence-parallelism
+        # feature. Currently it is an empirical value: if the concurrency
+        # exceeds this threshold, the performance benefits are maximized;
+        # conversely, if the concurrency is below the threshold, performance
+        # may degrade due to the switching of communication methods.
+        mmrs_fusion = True
+        if is_moe_model(vllm_config):
+            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
+            mmrs_fusion = False
+        else:
+            sp_enabled = enable_sp(vllm_config) and \
+                num_tokens is not None and num_tokens > 1000
+
+        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2.
+        flashcomm_v2_enabled = flashcomm2_enable(
+        ) and tp_world_size > 1 and num_tokens is not None
+        pad_size = 0
+        # Initialize here so the return dict never hits an unbound local when
+        # the padding branches below are skipped.
+        padded_length = num_tokens
+        mc2_mask = None
+        if (sp_enabled or flashcomm_v2_enabled):
+            pad_size = (tp_world_size -
+                        (num_tokens % tp_world_size)) % tp_world_size
+
+        dp_world_size = get_dp_group().world_size
+        if dp_world_size > 1 and dp_metadata is not None:
+            max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
+            if (sp_enabled or flashcomm_v2_enabled):
+                padded_length = (max_tokens_across_dp + tp_world_size -
+                                 1) // tp_world_size * tp_world_size
+                pad_size = padded_length - num_tokens
+        else:
+            max_tokens_across_dp = num_tokens
+
+        if num_tokens is not None:
+            # NOTE: the token count to pad to when using mc2.
+            padded_num_tokens = math.ceil(
+                max_tokens_across_dp / tp_world_size) * tp_world_size
+            reserved_mc2_mask = get_mc2_mask()
+            if reserved_mc2_mask is not None:
+                mc2_mask = reserved_mc2_mask[:padded_num_tokens]
+                mc2_mask[:num_actual_tokens] = True
+                mc2_mask[num_actual_tokens:] = False
+        return {
+            "moe_comm_type": moe_comm_type,
+            "moe_comm_method": moe_comm_method,
+            "capturing": capturing,
+            "mmrs_fusion": mmrs_fusion,
+            "num_tokens": num_tokens,
+            "sp_enabled": sp_enabled,
+            "flashcomm_v2_enabled": flashcomm_v2_enabled,
+            "pad_size": pad_size,
+            "padded_length": padded_length,
+            "max_tokens_across_dp": max_tokens_across_dp,
+            "mc2_mask": mc2_mask,
+        }
```
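To make the padding arithmetic in `set_additional_forward_context` concrete, a small self-contained example (values illustrative) of the two roundings the method performs:

```python
import math

# Illustrative values: 8-way tensor parallelism, 1003 tokens on this rank,
# and a data-parallel peer running 1500 tokens in the same step.
tp_world_size = 8
num_tokens = 1003
max_tokens_across_dp = 1500

# Without DP: pad the local token count up to a multiple of tp_world_size so
# sequence-parallel / flashcomm_v2 collectives split evenly across TP ranks.
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
assert pad_size == 5  # 1003 -> 1008

# With DP: every rank pads to the same length, derived from the largest token
# count across DP ranks, rounded up to a multiple of tp_world_size.
padded_length = (max_tokens_across_dp + tp_world_size - 1) \
    // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
assert (padded_length, pad_size) == (1504, 501)

# mc2 uses the same rounding, written with math.ceil; the reserved mask is
# then sliced to this length, True for real tokens and False for padding.
padded_num_tokens = math.ceil(
    max_tokens_across_dp / tp_world_size) * tp_world_size
assert padded_num_tokens == padded_length == 1504
```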