diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index c93aa78b..eee64fce 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -15,11 +15,13 @@
 # This file is a part of the vllm-ascend project.
 #
 
+import math
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from uuid import uuid4
 
 import torch
+import vllm.envs as envs_vllm
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
 
@@ -27,14 +29,14 @@ from vllm.platforms import Platform, PlatformEnum
 os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
 
 from vllm_ascend.ascend_config import init_ascend_config
-from vllm_ascend.utils import refresh_block_size
 # isort: off
 from vllm_ascend.utils import (
-    ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD,
-    COMPILATION_PASS_KEY, AscendDeviceType, enable_sp, get_ascend_device_type,
-    is_vl_model, update_aclgraph_sizes, update_cudagraph_capture_sizes,
-    update_default_aclgraph_sizes, check_kv_extra_config)
+    ASCEND_QUANTIZATION_METHOD, COMPILATION_PASS_KEY,
+    COMPRESSED_TENSORS_METHOD, AscendDeviceType, check_kv_extra_config,
+    enable_sp, flashcomm2_enable, get_ascend_device_type, is_moe_model,
+    is_vl_model, refresh_block_size, update_aclgraph_sizes,
+    update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
@@ -330,7 +332,8 @@ class NPUPlatform(Platform):
             compilation_config.custom_ops = ["all"]
 
         if ascend_config.recompute_scheduler_enable:
-            from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig
+            from vllm_ascend.core.recompute_scheduler import \
+                RecomputeSchedulerConfig
             recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
                 vllm_config)
             vllm_config.scheduler_config = recompute_scheduler_config
@@ -435,3 +438,134 @@ class NPUPlatform(Platform):
     @classmethod
     def support_static_graph_mode(cls) -> bool:
         return True
+
+    @classmethod
+    def set_additional_forward_context(
+        cls,
+        attn_metadata: dict[str, Any],
+        vllm_config: VllmConfig,
+        dp_metadata,
+        virtual_engine: int = 0,
+        num_tokens: int | None = None,
+        num_tokens_across_dp: torch.Tensor | None = None,
+        cudagraph_runtime_mode=None,
+        batch_descriptor=None,
+        ubatch_slices=None,
+    ) -> dict[str, Any]:
+        """Set additional forward context for Ascend NPUs.
+
+        Args:
+            attn_metadata (dict[str, Any]): Attention metadata for all layers.
+            vllm_config (VllmConfig): vLLM configuration.
+            dp_metadata (DPMetadata): Metadata for data parallelism. No type
+                hint because of a circular import.
+            virtual_engine (int, optional): Index of the virtual engine.
+                Defaults to 0.
+            num_tokens (int | None, optional): Number of tokens. Defaults to
+                None.
+            num_tokens_across_dp (torch.Tensor | None, optional): Number of
+                tokens across data parallel ranks. Defaults to None.
+            cudagraph_runtime_mode (CUDAGraphMode, optional): Cudagraph
+                runtime mode. Defaults to None. No type hint because of a
+                circular import.
+            batch_descriptor (BatchDescriptor, optional): Descriptor of the
+                batch. Defaults to None.
+            ubatch_slices (UBatchSlices, optional): Slice info for dual-batch
+                execution. Defaults to None. No type hint because of a
+                circular import.
+
+        Returns:
+            dict[str, Any]: Additional fields to attach to the forward
+                context.
+        """
+        # NOTE(Ronald1995): avoid circular import.
+        from vllm_ascend.ascend_forward_context import (get_mc2_mask,
+                                                        select_moe_comm_method)
+        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
+        from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
+        # NOTE(Ronald1995): avoid circular import. cudagraph_runtime_mode is
+        # CUDAGraphMode.NONE in vllm, but we can't use CUDAGraphMode.NONE as
+        # the argument default, so we default it to None and set it to
+        # CUDAGraphMode.NONE here.
+        from vllm.config import CUDAGraphMode
+        if cudagraph_runtime_mode is None:
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
+        # TODO(Ronald1995): model runner v1 still uses ascend_forward_context;
+        # once v1's forward context is refactored, we can remove this branch.
+        # Currently, model runner v2 uses the new forward context. Compared to
+        # v1, v2's forward context lacks some fields, such as:
+        # in_profile_run, is_first_layer, prefetch_mlp_gate_up_proj,
+        # prefetch_mlp_gate_down_proj, prefetch_mlp_enabled, model_instance,
+        # is_draft_model.
+        if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
+            return {}
+
+        num_actual_tokens = list(attn_metadata.values())[0].num_actual_tokens
+        if num_tokens is None:
+            num_tokens = num_actual_tokens
+
+        moe_comm_type = select_moe_comm_method(
+            num_tokens,
+            vllm_config,
+            # is_draft_model will be removed later, so set it to False for now.
+            is_draft_model=False,
+        )
+        moe_comm_method = get_moe_comm_method(moe_comm_type)
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+
+        # NOTE: This cannot be set via set_forward_context because of the
+        # multiple warmups that happen before actual capturing.
+        capturing = False
+
+        # Set for sequence parallelism. 1000 is the batch-size (concurrency)
+        # threshold for enabling the flashcomm_v1 / sequence_parallelism
+        # feature. It is currently an empirical value: above this threshold
+        # the performance benefit is maximized, while below it performance
+        # may degrade because of the switch in communication methods.
+        mmrs_fusion = True
+        if is_moe_model(vllm_config):
+            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
+            mmrs_fusion = False
+        else:
+            sp_enabled = enable_sp(vllm_config) and \
+                num_tokens is not None and num_tokens > 1000
+
+        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
+        flashcomm_v2_enabled = flashcomm2_enable(
+        ) and tp_world_size > 1 and num_tokens is not None
+        # Defaults, so the return dict below never references an unbound name
+        # when the branches that compute these values are skipped.
+        pad_size = 0
+        padded_length = 0
+        mc2_mask = None
+        if (sp_enabled or flashcomm_v2_enabled):
+            pad_size = (tp_world_size -
+                        (num_tokens % tp_world_size)) % tp_world_size
+
+        dp_world_size = get_dp_group().world_size
+        if dp_world_size > 1 and dp_metadata is not None:
+            max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
+            if (sp_enabled or flashcomm_v2_enabled):
+                padded_length = (max_tokens_across_dp + tp_world_size -
+                                 1) // tp_world_size * tp_world_size
+                pad_size = padded_length - num_tokens
+        else:
+            max_tokens_across_dp = num_tokens
+
+        if num_tokens is not None:
+            # NOTE: number of tokens to pad to for MC2.
+            padded_num_tokens = math.ceil(
+                max_tokens_across_dp / tp_world_size) * tp_world_size
+            reserved_mc2_mask = get_mc2_mask()
+            if reserved_mc2_mask is not None:
+                mc2_mask = reserved_mc2_mask[:padded_num_tokens]
+                mc2_mask[:num_actual_tokens] = True
+                mc2_mask[num_actual_tokens:] = False
+        return {
+            "moe_comm_type": moe_comm_type,
+            "moe_comm_method": moe_comm_method,
+            "capturing": capturing,
+            "mmrs_fusion": mmrs_fusion,
+            "num_tokens": num_tokens,
+            "sp_enabled": sp_enabled,
+            "flashcomm_v2_enabled": flashcomm_v2_enabled,
+            "pad_size": pad_size,
+            "padded_length": padded_length,
+            "max_tokens_across_dp": max_tokens_across_dp,
+            "mc2_mask": mc2_mask,
+        }
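
Note on the padding arithmetic in the new method: both pad_size and the MC2 mask length round token counts up to a multiple of the tensor-parallel world size. A small standalone sketch of that arithmetic (not part of the patch; the helper name is only for illustration):

    import math

    def ascend_padding(num_tokens: int, max_tokens_across_dp: int,
                       tp_world_size: int):
        # Tokens to append so num_tokens becomes a multiple of tp_world_size
        # (the sp / flashcomm_v2 pad_size in the patch).
        pad_size = (tp_world_size -
                    (num_tokens % tp_world_size)) % tp_world_size
        # DP-wide maximum rounded up to a multiple of tp_world_size
        # (the padded_num_tokens used to slice the reserved MC2 mask).
        padded_num_tokens = math.ceil(
            max_tokens_across_dp / tp_world_size) * tp_world_size
        return pad_size, padded_num_tokens

    # With tp_world_size=4: 10 local tokens need pad_size=2, and a DP-wide
    # maximum of 13 tokens yields an MC2 mask length of 16.
    assert ascend_padding(10, 13, 4) == (2, 16)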
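
The MC2 mask built at the end of the method marks which of the padded slots hold real tokens. A minimal torch sketch of the same slicing (not part of the patch), assuming only a preallocated boolean buffer like the one returned by get_mc2_mask():

    import torch

    def build_mc2_mask(reserved_mask: torch.Tensor, padded_num_tokens: int,
                       num_actual_tokens: int) -> torch.Tensor:
        # View of the first padded_num_tokens entries of the reserved buffer:
        # True for real tokens, False for padding slots.
        mc2_mask = reserved_mask[:padded_num_tokens]
        mc2_mask[:num_actual_tokens] = True
        mc2_mask[num_actual_tokens:] = False
        return mc2_mask

    # e.g. a reserved buffer of 32 slots, 16 padded slots, 13 real tokens.
    reserved = torch.zeros(32, dtype=torch.bool)
    mask = build_mc2_mask(reserved, 16, 13)
    assert mask.numel() == 16 and mask.sum().item() == 13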