From 52086394ae6605ce33186bdfb55a79c29bd2a519 Mon Sep 17 00:00:00 2001 From: SILONG ZENG <2609716663@qq.com> Date: Fri, 16 Jan 2026 20:57:46 +0800 Subject: [PATCH] [Lint]Style: Convert `vllm-ascend/compilation` to ruff format (#5912) ### What this PR does / why we need it? Convert `vllm-ascend/compilation` to ruff format. ### Does this PR introduce _any_ user-facing change? During this migration, we encountered some **errors** in our CI and testing environments, such as: ``` vllm_ascend/utils.py:653: in def register_ascend_customop(vllm_config: VllmConfig | None = None): ^^^^^^^^^^^^^^^^^ E TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType' ``` **1. Root Cause Analysis:** The project uses a common pattern to break circular dependencies: ```python if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None # Placeholder assigned at runtime ``` When Python parses the function definition `def register_ascend_customop(vllm_config: VllmConfig | None)`, it attempts to evaluate the expression `VllmConfig | None`. Since `VllmConfig` is assigned `None` at runtime, the expression effectively becomes `None | None`. In Python, `None` is an instance of `NoneType`. While the `|` operator is implemented for Type objects (classes), it is not supported for `NoneType` instances, leading to the `TypeError` shown above. **2. Solution:** To maintain the modern `|` syntax required by our new linting standards while preserving our dependency management strategy, I have introduced: ```python from __future__ import annotations ``` at the top of the affected files. This enables **Postponed Evaluation of Annotations (PEP 563)**. **3. Impact and Benefits:** - By enabling `annotations`, Python no longer executes the `VllmConfig | None` operation during module load. Instead, it stores the annotation as a string literal, completely avoiding the `None | None` calculation. - We can keep the `VllmConfig = None` placeholders. This ensures that other modules can still import these symbols without triggering an `ImportError`, maintaining a stable dependency graph. - IDEs and static type checkers (MyPy/Pyright) continue to resolve the types correctly. This allows us to use modern syntax without sacrificing type safety or runtime stability. - The only side effect is that `__annotations__` will now return strings instead of type objects. Since this module does not use runtime type enforcement or reflection, this change has zero negative impact on existing functionality. ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/11b6af5280d6d6dfb8953af16e67b25f819b3be9 --------- Signed-off-by: MrZ20 <2609716663@qq.com> --- pyproject.toml | 20 +- vllm_ascend/__init__.py | 3 + vllm_ascend/ascend_config.py | 163 +++----- vllm_ascend/ascend_forward_context.py | 155 ++++---- vllm_ascend/compilation/acl_graph.py | 301 ++++++++------- vllm_ascend/compilation/compiler_interface.py | 43 +-- .../compilation/graph_fusion_pass_manager.py | 10 +- .../npugraph_ex_passes/add_rms_norm_quant.py | 195 +++++----- .../passes/norm_quant_fusion_pass.py | 226 +++++------ .../passes/qknorm_rope_fusion_pass.py | 218 ++++------- vllm_ascend/cpu_binding.py | 121 +++--- vllm_ascend/flash_common3_context.py | 25 +- vllm_ascend/meta_registration.py | 56 +-- vllm_ascend/platform.py | 207 +++++----- vllm_ascend/profiling_config.py | 37 +- vllm_ascend/utils.py | 356 ++++++++---------- 16 files changed, 996 insertions(+), 1140 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 47efe0ee..b20a7bcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,25 @@ line-length = 120 # Folder to be modified exclude = [ "tests/**", - "vllm_ascend/**", + "vllm_ascend/_cann_ops_custom", + "vllm_ascend/attention", + "vllm_ascend/core", + "vllm_ascend/device", + "vllm_ascend/device_allocator", + "vllm_ascend/distributed", + "vllm_ascend/eplb", + "vllm_ascend/kv_offload", + "vllm_ascend/lora", + "vllm_ascend/model_loader", + "vllm_ascend/ops", + "vllm_ascend/patch", + "vllm_ascend/quantization", + "vllm_ascend/sample", + "vllm_ascend/spec_decode", + "vllm_ascend/worker", + "vllm_ascend/xlite", + "vllm_ascend/envs.py", + "vllm_ascend/batch_invariant.py", ] [tool.ruff.lint] diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 997f97cf..26e22a5d 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -24,14 +24,17 @@ def register(): def register_connector(): from vllm_ascend.distributed.kv_transfer import register_connector + register_connector() def register_model_loader(): from .model_loader.netloader import register_netloader + register_netloader() def register_service_profiling(): from .profiling_config import generate_service_profiling_config + generate_service_profiling_config() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 49a5409a..7b9737f2 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from vllm.logger import logger from vllm.triton_utils import HAS_TRITON @@ -32,18 +32,13 @@ class AscendConfig: additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} xlite_graph_config = additional_config.get("xlite_graph_config", {}) - self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, - vllm_config) + self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, vllm_config) - ascend_compilation_config = additional_config.get( - "ascend_compilation_config", {}) - self.ascend_compilation_config = AscendCompilationConfig( - **ascend_compilation_config) + ascend_compilation_config = additional_config.get("ascend_compilation_config", {}) + self.ascend_compilation_config = AscendCompilationConfig(**ascend_compilation_config) - finegrained_tp_config = additional_config.get("finegrained_tp_config", - {}) - self.finegrained_tp_config = FinegrainedTPConfig( - finegrained_tp_config, vllm_config) + finegrained_tp_config = additional_config.get("finegrained_tp_config", {}) + self.finegrained_tp_config = FinegrainedTPConfig(finegrained_tp_config, vllm_config) eplb_config = additional_config.get("eplb_config", {}) self.eplb_config = EplbConfig(eplb_config) @@ -51,10 +46,8 @@ class AscendConfig: # Dump / PrecisionDebugger configuration self.dump_config_path = additional_config.get("dump_config_path", None) - weight_prefetch_config = additional_config.get( - "weight_prefetch_config", {}) - self.weight_prefetch_config = WeightPrefetchConfig( - weight_prefetch_config) + weight_prefetch_config = additional_config.get("weight_prefetch_config", {}) + self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config) self.layer_sharding = additional_config.get("layer_sharding", None) logger.info_once( f"Linear layer sharding enabled with config: {self.layer_sharding}. " @@ -62,30 +55,25 @@ class AscendConfig: "using it without these features may result in significant performance degradation." ) - self.enable_shared_expert_dp = additional_config.get( - "enable_shared_expert_dp", - False) and vllm_config.parallel_config.enable_expert_parallel + self.enable_shared_expert_dp = ( + additional_config.get("enable_shared_expert_dp", False) + and vllm_config.parallel_config.enable_expert_parallel + ) if self.enable_shared_expert_dp: from vllm_ascend.utils import enable_sp - assert enable_sp(vllm_config=vllm_config, - enable_shared_expert_dp=True) - self.multistream_overlap_shared_expert = additional_config.get( - "multistream_overlap_shared_expert", False) - self.multistream_overlap_gate = additional_config.get( - "multistream_overlap_gate", False) - self.recompute_scheduler_enable = additional_config.get( - "recompute_scheduler_enable", False) - self.enable_cpu_binding = additional_config.get( - "enable_cpu_binding", False) + + assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True) + self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False) + self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False) + self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False) + self.enable_cpu_binding = additional_config.get("enable_cpu_binding", False) self.pd_tp_ratio = 1 self.pd_head_ratio = 1 self.num_head_replica = 1 if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla: - prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config( - "prefill", {"tp_size": 1})["tp_size"] - decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config( - "decode", {"tp_size": 1})["tp_size"] + prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {"tp_size": 1})["tp_size"] + decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("decode", {"tp_size": 1})["tp_size"] assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size." self.pd_tp_ratio = prefill_tp_size // decode_tp_size if self.pd_tp_ratio > 1: @@ -106,36 +94,29 @@ class AscendConfig: ) if self.pd_tp_ratio == 0: - raise AssertionError( - "Only support P node tp size lagger then D node tp size") - self.SLO_limits_for_dynamic_batch = additional_config.get( - "SLO_limits_for_dynamic_batch", -1) + raise AssertionError("Only support P node tp size lagger then D node tp size") + self.SLO_limits_for_dynamic_batch = additional_config.get("SLO_limits_for_dynamic_batch", -1) from vllm_ascend.utils import get_flashcomm2_config_and_validate - self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate( - self, vllm_config) - self.enable_npugraph_ex = additional_config.get( - "enable_npugraph_ex", False) + + self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config) + self.enable_npugraph_ex = additional_config.get("enable_npugraph_ex", False) # We find that _npu_paged_attention still performs better than # npu_fused_infer_attention_score in some cases. We allow to execute # _npu_paged_attention in this cases. This should be removed once # npu_fused_infer_attention_score performs better on all scenarios. self.pa_shape_list = additional_config.get("pa_shape_list", []) - self.enable_async_exponential = bool( - additional_config.get("enable_async_exponential", False)) + self.enable_async_exponential = bool(additional_config.get("enable_async_exponential", False)) self.enable_kv_nz = additional_config.get("enable_kv_nz", False) if self.enable_kv_nz: - use_sparse = hasattr(vllm_config.model_config.hf_text_config, - "index_topk") + use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk") if not vllm_config.model_config.is_deepseek_mla or use_sparse: - raise RuntimeError( - "enable_kv_nz is only supported for mla currently.") - if vllm_config.kv_transfer_config is None \ - or not vllm_config.kv_transfer_config.is_kv_consumer: + raise RuntimeError("enable_kv_nz is only supported for mla currently.") + if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer: raise NotImplementedError( - "enable_kv_nz is only supported in pd scenario and can " - "only be used in D node.") + "enable_kv_nz is only supported in pd scenario and can only be used in D node." + ) class FinegrainedTPConfig: @@ -144,40 +125,28 @@ class FinegrainedTPConfig: """ def __init__(self, finegrained_tp_config: dict, vllm_config): - self.oproj_tensor_parallel_size = finegrained_tp_config.get( - "oproj_tensor_parallel_size", 0) - self.lmhead_tensor_parallel_size = finegrained_tp_config.get( - "lmhead_tensor_parallel_size", 0) - self.embedding_tensor_parallel_size = finegrained_tp_config.get( - "embedding_tensor_parallel_size", 0) - self.mlp_tensor_parallel_size = finegrained_tp_config.get( - "mlp_tensor_parallel_size", 0) + self.oproj_tensor_parallel_size = finegrained_tp_config.get("oproj_tensor_parallel_size", 0) + self.lmhead_tensor_parallel_size = finegrained_tp_config.get("lmhead_tensor_parallel_size", 0) + self.embedding_tensor_parallel_size = finegrained_tp_config.get("embedding_tensor_parallel_size", 0) + self.mlp_tensor_parallel_size = finegrained_tp_config.get("mlp_tensor_parallel_size", 0) enabled_configs = [] if self.oproj_tensor_parallel_size > 0: - enabled_configs.append( - f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}" - ) - # dummy_run does not run the entire attention module in eager mode,, so the o_proj tp split can only be used in graph mode. + enabled_configs.append(f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}") + # dummy_run does not run the entire attention module in eager mode, + # so the o_proj tp split can only be used in graph mode. if vllm_config.model_config.enforce_eager is True: - raise AssertionError( - "oproj_tensor_parallel_size is only supported in graph mode" - ) + raise AssertionError("oproj_tensor_parallel_size is only supported in graph mode") if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer: raise AssertionError( "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node." ) if self.lmhead_tensor_parallel_size > 0: - enabled_configs.append( - f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}" - ) + enabled_configs.append(f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}") if self.embedding_tensor_parallel_size > 0: - enabled_configs.append( - f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}" - ) + enabled_configs.append(f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}") if self.mlp_tensor_parallel_size > 0: - enabled_configs.append( - f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}") + enabled_configs.append(f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}") module_tp_sizes = [ self.oproj_tensor_parallel_size, self.lmhead_tensor_parallel_size, @@ -186,11 +155,9 @@ class FinegrainedTPConfig: ] for module_tp_size in module_tp_sizes: if module_tp_size > 0 and vllm_config.parallel_config.data_parallel_size % module_tp_size != 0: - raise AssertionError( - "module tp sizes must divide data_parallel_size") + raise AssertionError("module tp sizes must divide data_parallel_size") if any(size > 0 for size in module_tp_sizes) and enabled_configs: - logger.info( - f"finegrained_tp_config enabled: {', '.join(enabled_configs)}") + logger.info(f"finegrained_tp_config enabled: {', '.join(enabled_configs)}") class AscendCompilationConfig: @@ -202,13 +169,10 @@ class AscendCompilationConfig: deployed on Ascend platforms. """ - def __init__(self, - fuse_norm_quant: bool = True, - fuse_qknorm_rope: bool = False, - **kwargs): + def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, **kwargs): """ Initialize the configuration. - + Args: fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization. When set to True, the system will optimize norm and quant operations. @@ -236,7 +200,8 @@ class XliteGraphConfig: ) if vllm_config.parallel_config.pipeline_parallel_size > 1: raise RuntimeError( - "Xlite graph mode is not compatible with pipeline parallelism. Please set pipeline_parallel_size to 1." + "Xlite graph mode is not compatible with pipeline parallelism. " + "Please set pipeline_parallel_size to 1." ) if vllm_config.cache_config.block_size != 128: raise RuntimeError( @@ -254,21 +219,19 @@ class WeightPrefetchConfig: "qkv": 1.0, "o": 1.0, }, - "moe": { - "gate_up": 0.8 - } + "moe": {"gate_up": 0.8}, } def __init__(self, weight_prefetch_config: dict): self.enabled = weight_prefetch_config.get("enabled", False) - self.prefetch_ratio = weight_prefetch_config.get( - "prefetch_ratio", self.prefetch_ratio) + self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio) class EplbConfig: """ Configuration Object for xlite_graph_config from additional_config """ + _defaults = { "dynamic_eplb": False, "expert_map_path": None, @@ -276,10 +239,12 @@ class EplbConfig: "algorithm_execution_interval": 30, "expert_map_record_path": None, "num_redundant_experts": 0, - "eplb_policy_type": 1 + "eplb_policy_type": 1, } - def __init__(self, user_config: dict = {}): + def __init__(self, user_config: dict | None = None): + if user_config is None: + user_config = {} self.config = self._defaults.copy() if user_config and isinstance(user_config, dict): for key, value in user_config.items(): @@ -307,27 +272,21 @@ class EplbConfig: raise TypeError("The expert_map_record_path is not json.") dirname = os.path.dirname(self.expert_map_record_path) os.makedirs(dirname, exist_ok=True) - for key in [ - "expert_heat_collection_interval", - "algorithm_execution_interval", "num_redundant_experts" - ]: + for key in ["expert_heat_collection_interval", "algorithm_execution_interval", "num_redundant_experts"]: if not isinstance(self.config[key], int): raise TypeError(f"{key} must be an integer") if self.config[key] < 0: # type: ignore - raise ValueError( - f"{key} must greater than 0; got {self.config[key]} instead" - ) + raise ValueError(f"{key} must greater than 0; got {self.config[key]} instead") if self.eplb_policy_type not in [0, 1, 2, 3]: raise ValueError("eplb_policy_type must in [0, 1, 2, 3]") -_ASCEND_CONFIG: Optional[AscendConfig] = None +_ASCEND_CONFIG: AscendConfig | None = None def init_ascend_config(vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} - refresh = additional_config.get("refresh", - False) if additional_config else False + refresh = additional_config.get("refresh", False) if additional_config else False global _ASCEND_CONFIG if _ASCEND_CONFIG is not None and not refresh: return _ASCEND_CONFIG @@ -343,7 +302,5 @@ def clear_ascend_config(): def get_ascend_config(): global _ASCEND_CONFIG if _ASCEND_CONFIG is None: - raise RuntimeError( - "Ascend config is not initialized. Please call init_ascend_config first." - ) + raise RuntimeError("Ascend config is not initialized. Please call init_ascend_config first.") return _ASCEND_CONFIG diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 092544b7..49693474 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -1,21 +1,25 @@ import math from contextlib import contextmanager from enum import Enum -from typing import Any, Optional +from typing import Any import torch from vllm.config import CUDAGraphMode, VllmConfig -from vllm.distributed import (get_dp_group, get_ep_group, - get_tensor_model_parallel_world_size) -from vllm.forward_context import (BatchDescriptor, get_forward_context, - set_forward_context) +from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size +from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable, - get_ascend_device_type, has_layer_idx, - is_drafter_moe_model, is_moe_model, - speculative_enable_dispatch_gmm_combine_decode) +from vllm_ascend.utils import ( + AscendDeviceType, + enable_sp, + flashcomm2_enable, + get_ascend_device_type, + has_layer_idx, + is_drafter_moe_model, + is_moe_model, + speculative_enable_dispatch_gmm_combine_decode, +) class MoECommType(Enum): @@ -27,36 +31,36 @@ class MoECommType(Enum): @contextmanager def set_ascend_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: int = 0, - num_tokens_across_dp: Optional[torch.Tensor] = None, - in_profile_run: bool = False, - num_actual_tokens: Optional[int] = None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None, - model_instance: torch.nn.Module = None, - is_draft_model=False): + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: int = 0, + num_tokens_across_dp: torch.Tensor | None = None, + in_profile_run: bool = False, + num_actual_tokens: int | None = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: BatchDescriptor | None = None, + model_instance: torch.nn.Module = None, + is_draft_model=False, +): """A context manager that stores the current forward context, can be attention metadata, etc. We add some additional param into forward_context. """ with set_forward_context( - attn_metadata, - vllm_config, - virtual_engine=virtual_engine, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor, + attn_metadata, + vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, ): forward_context = get_forward_context() - from vllm_ascend.ops.fused_moe.moe_comm_method import \ - get_moe_comm_method - moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, - is_draft_model) + from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method + + moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model) forward_context.moe_comm_type = moe_comm_type forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type) @@ -68,31 +72,28 @@ def set_ascend_forward_context( # due to multiple warmups before actual capturing forward_context.capturing = False - # set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature. - # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, - # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, + # set for sequence parallelism, 1000 is the batch size concurrency threshold + # for enabling the flashcomm_v1 or sequence_parallelism feature. + # Currently, it is an empirical value. In normal scenarios, if the concurrency + # exceeds this threshold, the performance benefits can be maximized. + # Conversely, if the concurrency is below the threshold, # the performance may degrade due to the switching of communication methods. mmrs_fusion = True # main model and drafter model may have different architecture - is_context_moe_model = is_drafter_moe_model(vllm_config) \ - if is_draft_model else is_moe_model(vllm_config) + is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config) if is_context_moe_model: sp_enabled = enable_sp(vllm_config) and num_tokens is not None mmrs_fusion = False else: - sp_enabled = enable_sp(vllm_config) and \ - num_tokens is not None and num_tokens > 1000 + sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000 forward_context.mmrs_fusion = mmrs_fusion forward_context.num_tokens = num_tokens forward_context.sp_enabled = sp_enabled # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 - forward_context.flashcomm_v2_enabled = flashcomm2_enable( - ) and tp_world_size > 1 and num_tokens is not None + forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None - if (forward_context.sp_enabled - or forward_context.flashcomm_v2_enabled): - pad_size = (tp_world_size - - (num_tokens % tp_world_size)) % tp_world_size + if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled: + pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size forward_context.pad_size = pad_size # set this for rope forward_oot using @@ -106,9 +107,12 @@ def set_ascend_forward_context( # TODO(rjg-lyh): refactor mlp weight prefetch method # set for mlp weight prefetch - prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \ - forward_context.layer_idx is not None and \ - num_tokens is not None and num_tokens < 500 + prefetch_mlp_enabled = ( + envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP + and forward_context.layer_idx is not None + and num_tokens is not None + and num_tokens < 500 + ) if prefetch_mlp_enabled: forward_context.prefetch_mlp_gate_up_proj = False forward_context.prefetch_mlp_down_proj = False @@ -121,12 +125,9 @@ def set_ascend_forward_context( dp_world_size = get_dp_group().world_size if dp_world_size > 1 and forward_context.dp_metadata is not None: - max_tokens_across_dp = \ - forward_context.dp_metadata.max_tokens_across_dp_cpu.item() - if (forward_context.sp_enabled - or forward_context.flashcomm_v2_enabled): - padded_length = (max_tokens_across_dp + tp_world_size - - 1) // tp_world_size * tp_world_size + max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item() + if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled: + padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size pad_size = padded_length - num_tokens forward_context.padded_length = padded_length forward_context.pad_size = pad_size @@ -139,12 +140,10 @@ def set_ascend_forward_context( if num_actual_tokens is None: num_actual_tokens = num_tokens # NOTE: token num which need to pad to when mc2 - forward_context.padded_num_tokens = math.ceil( - max_tokens_across_dp / tp_world_size) * tp_world_size + forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size reserved_mc2_mask = get_mc2_mask() if reserved_mc2_mask is not None: - mc2_mask = reserved_mc2_mask[:forward_context. - padded_num_tokens] + mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens] mc2_mask[:num_actual_tokens] = True mc2_mask[num_actual_tokens:] = False forward_context.mc2_mask = mc2_mask @@ -155,14 +154,13 @@ def set_ascend_forward_context( pass -_mc2_tokens_capacity: Optional[int] = None -_reserved_mc2_mask: Optional[torch.Tensor] = None -_sin: Optional[torch.Tensor] = None -_cos: Optional[torch.Tensor] = None +_mc2_tokens_capacity: int | None = None +_reserved_mc2_mask: torch.Tensor | None = None +_sin: torch.Tensor | None = None +_cos: torch.Tensor | None = None -def set_mc2_tokens_capacity(vllm_config, max_num_reqs, - uniform_decode_query_len): +def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len): global _mc2_tokens_capacity if _mc2_tokens_capacity is not None: return @@ -186,9 +184,7 @@ def set_mc2_mask(vllm_config, device): if _reserved_mc2_mask is not None: return if is_moe_model(vllm_config): - _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), - dtype=torch.bool, - device=device) + _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device) else: _reserved_mc2_mask = None @@ -197,9 +193,7 @@ def get_mc2_mask(): return _reserved_mc2_mask -def select_moe_comm_method(num_tokens: int, - vllm_config: VllmConfig, - is_draft_model=False) -> Optional[MoECommType]: +def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model=False) -> MoECommType | None: """Select the MoE communication method according to parallel settings, device generation, token count, and quantization. @@ -227,16 +221,19 @@ def select_moe_comm_method(num_tokens: int, mc2_tokens_capacity = get_mc2_tokens_capacity() soc_version = get_ascend_device_type() quant_type = getattr( - vllm_config.model_config.hf_text_config, 'moe_quantize', - getattr(vllm_config.model_config.hf_text_config, 'quantize', None)) + vllm_config.model_config.hf_text_config, + "moe_quantize", + getattr(vllm_config.model_config.hf_text_config, "quantize", None), + ) - if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group( - ).world_size == 1: + if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1: moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendDeviceType.A2}: - if (num_tokens <= mc2_tokens_capacity - and vllm_config.parallel_config.world_size_across_dp / - vllm_config.parallel_config.pipeline_parallel_size >= 16): + if ( + num_tokens <= mc2_tokens_capacity + and vllm_config.parallel_config.world_size_across_dp / vllm_config.parallel_config.pipeline_parallel_size + >= 16 + ): moe_comm_type = MoECommType.MC2 else: moe_comm_type = MoECommType.ALLGATHER @@ -246,15 +243,13 @@ def select_moe_comm_method(num_tokens: int, # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16 fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" - dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and ( - not is_draft_model) and (not dynamic_eplb) + dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb) if num_tokens <= mc2_tokens_capacity: fused_decode_enable = fused_mc2_enable if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: - fused_decode_enable = fused_mc2_enable and \ - speculative_enable_dispatch_gmm_combine_decode(vllm_config) + fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config) moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2 else: fused_prefill_enable = fused_mc2_enable diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 97bfb03d..c81d2d2f 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -2,9 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +from collections.abc import Callable from contextlib import ExitStack from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any from unittest.mock import patch import numpy as np @@ -27,12 +28,12 @@ from ..utils import weak_ref_tensors @dataclasses.dataclass class ACLGraphEntry: batch_descriptor: BatchDescriptor - aclgraph: Optional[torch.npu.NPUGraph] = None - output: Optional[Any] = None + aclgraph: torch.npu.NPUGraph | None = None + output: Any | None = None # for aclgraph debugging, track the input addresses # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None + input_addresses: list[int] | None = None class ACLGraphWrapper: @@ -60,11 +61,13 @@ class ACLGraphWrapper: guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". """ - def __init__(self, - runnable: Callable, - vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, - cudagraph_options: Optional[CUDAGraphOptions] = None): + def __init__( + self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + cudagraph_options: CUDAGraphOptions | None = None, + ): self.runnable = runnable self.vllm_config = vllm_config self.runtime_mode = runtime_mode @@ -83,15 +86,13 @@ class ACLGraphWrapper: self.aclgraph_options = cudagraph_options # the entries for different batch descriptors that we need to capture # aclgraphs for. - self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\ - = {} + self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {} def __getattr__(self, key: str): # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): return getattr(self.runnable, key) - raise AttributeError(f"Attribute {key} not exists in the runnable of " - f"aclgraph wrapper: {self.runnable}") + raise AttributeError(f"Attribute {key} not exists in the runnable of aclgraph wrapper: {self.runnable}") def unwrap(self) -> Callable: # in case we need to access the original runnable. @@ -102,8 +103,7 @@ class ACLGraphWrapper: batch_descriptor = forward_context.batch_descriptor aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode - if aclgraph_runtime_mode == CUDAGraphMode.NONE or \ - aclgraph_runtime_mode != self.runtime_mode: + if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode: # CUDAGraphMode.NONE could mean the profile run, a warmup run, or # running without aclgraphs. # We do not trigger capture/replay if the runtime mode is not @@ -114,8 +114,7 @@ class ACLGraphWrapper: if batch_descriptor not in self.concrete_aclgraph_entries: # create a new entry for this batch descriptor - self.concrete_aclgraph_entries[batch_descriptor] = \ - ACLGraphEntry(batch_descriptor=batch_descriptor) + self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor) entry = self.concrete_aclgraph_entries[batch_descriptor] @@ -125,14 +124,11 @@ class ACLGraphWrapper: # capturing is fast, we don't need to log it for every # shape. E.g. we only log it for the first subgraph in # piecewise mode. - logger.debug("Capturing a aclgraph on (%s,%s)", - self.runtime_mode.name, entry.batch_descriptor) + logger.debug("Capturing a aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor) # validate that aclgraph capturing is legal at this point. validate_cudagraph_capturing_enabled() - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] + input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)] entry.input_addresses = input_addresses aclgraph = torch.npu.NPUGraph() @@ -145,8 +141,7 @@ class ACLGraphWrapper: # therefore, we only run gc for the first graph, # and disable gc for the rest of the graphs. stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.npu.empty_cache", lambda: None)) + stack.enter_context(patch("torch.npu.empty_cache", lambda: None)) # mind-exploding: carefully manage the reference and memory. forward_context.capturing = True @@ -183,13 +178,12 @@ class ACLGraphWrapper: if self.is_debugging_mode: # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] + new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)] assert new_input_addresses == entry.input_addresses, ( f"Input addresses for aclgraphs are different " f"during replay. Expected {entry.input_addresses}, " - f"got {new_input_addresses}") + f"got {new_input_addresses}" + ) logger.info_once("Replaying aclgraph") # In async scheduling or multi-threaded (MT) scenarios when graph mode is FULL, it is possible that @@ -209,8 +203,7 @@ def weak_ref_workspaces(params): for num_tokens in params.workspaces: if params.workspaces[num_tokens] is None: continue - params.workspaces[num_tokens] = weak_ref_tensors( - params.workspaces[num_tokens]) + params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens]) def _update_attn_pa_params(update_stream, forward_context, runtime_shape): @@ -219,10 +212,10 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape): # for each layer's attention op in the graph. with torch.npu.stream(update_stream): for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], ): ( query, @@ -254,18 +247,21 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape): scale_value=scale, block_table=block_table, context_lens=seq_lens, - out=output) + out=output, + ) torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu._npu_paged_attention(query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - workspace=workspace) + torch_npu._npu_paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=workspace, + ) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) @@ -282,18 +278,29 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape): # filters out the update operations for linear_attn. with torch.npu.stream(update_stream): for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], ): - (query, key_cache, value, block_tables, attn_mask, block_size, - seq_lens, query_start_loc, num_kv_heads, num_heads, scale, - attn_output, softmax_lse) = param + ( + query, + key_cache, + value, + block_tables, + attn_mask, + block_size, + seq_lens, + query_start_loc, + num_kv_heads, + num_heads, + scale, + attn_output, + softmax_lse, + ) = param seq_lens = forward_context.attn_metadata[key].seq_lens_list - actual_seq_lengths_q = forward_context.attn_metadata[ - key].actual_seq_lengths_q + actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q torch.npu.graph_task_update_begin(update_stream, handle) torch_npu.npu_fused_infer_attention_score.out( query=query, @@ -317,16 +324,14 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape): event.record(update_stream) -def update_attn_params(update_stream, forward_context, runtime_shape, - vllm_config): +def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config): if using_paged_attention(runtime_shape, vllm_config): _update_attn_pa_params(update_stream, forward_context, runtime_shape) else: _update_attn_fia_params(update_stream, forward_context, runtime_shape) -def update_mla_attn_params(update_stream, forward_context, runtime_shape, - speculative_config): +def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config): if forward_context.is_draft_model: graph_params = get_draft_graph_params() else: @@ -335,41 +340,44 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, # for each layer's attention op in the graph. with torch.npu.stream(update_stream): for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], ): - (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, - attn_mask, sparse_mode, scale, block_table, block_size, - seq_lens_list, actual_seq_lengths, attn_output, - softmax_lse) = param - seq_lens_list = forward_context.attn_metadata[ - key].decode.seq_lens_list - if speculative_config and speculative_config.method == "mtp" \ - and not forward_context.is_draft_model: - actual_seq_lengths = forward_context.attn_metadata[ - key].decode.actual_seq_lengths_q + ( + q_nope, + k_nope, + q_pe, + k_pe, + num_heads, + num_kv_heads, + input_layout, + attn_mask, + sparse_mode, + scale, + block_table, + block_size, + seq_lens_list, + actual_seq_lengths, + attn_output, + softmax_lse, + ) = param + seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list + if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model: + actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q spec_multiple = speculative_config.num_speculative_tokens + 1 - seq_lens_list = seq_lens_list + [0] * ( - runtime_shape // spec_multiple - len(seq_lens_list)) - actual_seq_lengths = [ - spec_multiple * (i + 1) - for i in range(runtime_shape // spec_multiple) - ] + seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list)) + actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)] elif forward_context.is_draft_model: - actual_seq_lengths = forward_context.attn_metadata[ - key].decode.actual_seq_lengths_q - block_table = forward_context.attn_metadata[ - key].decode.block_table + actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q + block_table = forward_context.attn_metadata[key].decode.block_table # TODO: This is a hack and should be fixed in the future. if speculative_config.disable_padded_drafter_batch: - block_table = block_table[:len(actual_seq_lengths)] - seq_lens_list = seq_lens_list + [0] * ( - len(actual_seq_lengths) - len(seq_lens_list)) + block_table = block_table[: len(actual_seq_lengths)] + seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list)) else: - seq_lens_list = seq_lens_list + [0] * (runtime_shape - - len(seq_lens_list)) + seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list)) torch.npu.graph_task_update_begin(update_stream, handle) torch_npu.npu_fused_infer_attention_score.out( @@ -391,7 +399,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, actual_seq_lengths_kv=seq_lens_list, actual_seq_lengths=actual_seq_lengths, workspace=graph_params.workspaces.get(runtime_shape), - out=[attn_output, softmax_lse]) + out=[attn_output, softmax_lse], + ) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) @@ -403,34 +412,40 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): graph_params = get_graph_params() with torch.npu.stream(update_stream): for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], ): - (q_nope, k_nope, value, num_heads, num_kv_heads, scale, - block_table, block_size, actual_seq_lengths_kv, - actual_seq_lengths_q, attn_output, softmax_lse, dcp_size, - pcp_rank, dcp_rank) = param + ( + q_nope, + k_nope, + value, + num_heads, + num_kv_heads, + scale, + block_table, + block_size, + actual_seq_lengths_kv, + actual_seq_lengths_q, + attn_output, + softmax_lse, + dcp_size, + pcp_rank, + dcp_rank, + ) = param attn_metadata = forward_context.attn_metadata[key] - actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, - pcp_rank, - dcp_rank] + actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank] pad_length = runtime_shape - len(actual_seq_lengths_kv) if pad_length > 0: - pad_tensor = np.zeros(pad_length, - dtype=actual_seq_lengths_kv.dtype) - actual_seq_lengths_kv = np.concatenate( - [actual_seq_lengths_kv, pad_tensor]) + pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype) + actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor]) - actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: - attn_metadata - . - num_decode_tokens] - if (runtime_shape - len(actual_seq_lengths_q)): - actual_seq_lengths_q = actual_seq_lengths_q + [ - actual_seq_lengths_q[-1] - ] * (runtime_shape - len(actual_seq_lengths_q)) + actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decode_tokens] + if runtime_shape - len(actual_seq_lengths_q): + actual_seq_lengths_q = actual_seq_lengths_q + [actual_seq_lengths_q[-1]] * ( + runtime_shape - len(actual_seq_lengths_q) + ) if dcp_size > 1: num_heads = num_heads * dcp_size @@ -453,14 +468,14 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): actual_seq_lengths_kv=actual_seq_lengths_kv, actual_seq_lengths=actual_seq_lengths_q, workspace=graph_params.workspaces.get(runtime_shape), - out=[attn_output, softmax_lse]) + out=[attn_output, softmax_lse], + ) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) -def update_mla_attn_dcp_pcp_params(update_stream, forward_context, - runtime_shape): +def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): if forward_context.is_draft_model: graph_params = get_draft_graph_params() else: @@ -469,13 +484,24 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, # for each layer's attention op in the graph. with torch.npu.stream(update_stream): for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], ): - (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, - scale, num_kv_heads, attn_output, softmax_lse) = param + ( + q_nope, + q_pe, + k_nope, + k_pe, + block_table, + seq_len, + num_heads, + scale, + num_kv_heads, + attn_output, + softmax_lse, + ) = param decode_meta = forward_context.attn_metadata[key].decode seq_len = decode_meta.cp_seq_len @@ -484,9 +510,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, # to avoid irregular attn_mask shape, # so there's no need to divide runtime_shape by spec_multiple pad_length = runtime_shape - len(seq_len) - pad_tensor = torch.zeros(pad_length, - dtype=seq_len.dtype, - device=seq_len.device) + pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device) seq_len = torch.cat([seq_len, pad_tensor], dim=0) torch.npu.graph_task_update_begin(update_stream, handle) @@ -505,7 +529,8 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, calc_type="calc_type_ring", workspace=graph_params.workspaces.get(runtime_shape), output=attn_output, - lse=softmax_lse) + lse=softmax_lse, + ) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) @@ -519,7 +544,7 @@ class GraphParams: attn_params: dict[int, list[tuple]] -_graph_params: Optional[GraphParams] = None +_graph_params: GraphParams | None = None def set_graph_params(aclgraph_capture_sizes: list[int]): @@ -527,14 +552,10 @@ def set_graph_params(aclgraph_capture_sizes: list[int]): if _graph_params is not None: raise ValueError("Graph parameters have already been set!") _graph_params = GraphParams( - {size: [] - for size in aclgraph_capture_sizes}, - {size: None - for size in aclgraph_capture_sizes}, - {size: [] - for size in aclgraph_capture_sizes}, - {size: [] - for size in aclgraph_capture_sizes}, + {size: [] for size in aclgraph_capture_sizes}, + {size: None for size in aclgraph_capture_sizes}, + {size: [] for size in aclgraph_capture_sizes}, + {size: [] for size in aclgraph_capture_sizes}, ) @@ -548,7 +569,7 @@ def get_graph_params(): return _graph_params -_draft_graph_params: Optional[GraphParams] = None +_draft_graph_params: GraphParams | None = None def set_draft_graph_params(aclgraph_capture_sizes: list[int]): @@ -556,14 +577,10 @@ def set_draft_graph_params(aclgraph_capture_sizes: list[int]): if _draft_graph_params is not None: raise ValueError("DraftGraph parameters have already been set!") _draft_graph_params = GraphParams( - {size: [] - for size in aclgraph_capture_sizes}, - {size: None - for size in aclgraph_capture_sizes}, - {size: [] - for size in aclgraph_capture_sizes}, - {size: [] - for size in aclgraph_capture_sizes}, + {size: [] for size in aclgraph_capture_sizes}, + {size: None for size in aclgraph_capture_sizes}, + {size: [] for size in aclgraph_capture_sizes}, + {size: [] for size in aclgraph_capture_sizes}, ) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 21bdedf7..f0e77410 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -16,13 +16,13 @@ # limitations under the License. # import functools -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import torch import torch.fx as fx from torch._dynamo.backends.common import aot_autograd -from torch._inductor.compile_fx import (graph_returns_tuple, - make_graph_return_tuple) +from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple from torch._inductor.decomposition import select_decomp_table from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface @@ -32,15 +32,11 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import COMPILATION_PASS_KEY -def compile_fx(graph: GraphModule, example_inputs: list, - inner_compile: Callable, decompositions: dict) -> Callable: - recursive_compile_fx = functools.partial(compile_fx, - inner_compile=inner_compile, - decompositions=decompositions) +def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable: + recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions) if not graph_returns_tuple(graph): - return make_graph_return_tuple(graph, example_inputs, - recursive_compile_fx) + return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx) return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs) @@ -49,9 +45,8 @@ def fusion_pass_compile( example_inputs: list[Any], compiler_config: dict[str, Any], compile_range: Range, - key: Optional[str] = None, -) -> tuple[Optional[Callable], Optional[Any]]: - + key: str | None = None, +) -> tuple[Callable | None, Any | None]: def compile_inner(graph, example_inputs): current_pass_manager = compiler_config[COMPILATION_PASS_KEY] graph = current_pass_manager(graph) @@ -74,8 +69,8 @@ def npugraph_ex_compile( example_inputs: list[Any], compiler_config: dict[str, Any], compile_range: Range, - key: Optional[str] = None, -) -> tuple[Optional[Callable], Optional[Any]]: + key: str | None = None, +) -> tuple[Callable | None, Any | None]: # When currently using the FULL_DECODE_ONLY mode, # the piecewise compilation level slicing process # in vllm is also encountered. @@ -87,10 +82,8 @@ def npugraph_ex_compile( output_node = fx_graph.output_node() with fx_graph.inserting_before(output_node): return_value = output_node.args[0] - tuple_node = fx_graph.create_node("call_function", - tuple, - args=([return_value], )) - output_node.args = (tuple_node, ) + tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],)) + output_node.args = (tuple_node,) graph.recompile() import torchair @@ -119,6 +112,7 @@ class AscendCompiler(CompilerInterface): This class provides a method to compile a PyTorch FX graph module with specific configurations for graph fusion and decomposition. """ + name = "AscendCompiler" def compile( @@ -127,13 +121,10 @@ class AscendCompiler(CompilerInterface): example_inputs: list[Any], compiler_config: dict[str, Any], compile_range: Range, - key: Optional[str] = None, - ) -> tuple[Optional[Callable], Optional[Any]]: - + key: str | None = None, + ) -> tuple[Callable | None, Any | None]: ascend_config = get_ascend_config() if ascend_config.enable_npugraph_ex: - return npugraph_ex_compile(graph, example_inputs, compiler_config, - compile_range, key) + return npugraph_ex_compile(graph, example_inputs, compiler_config, compile_range, key) else: - return fusion_pass_compile(graph, example_inputs, compiler_config, - compile_range, key) + return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key) diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index d4dab40e..7fc3f712 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -26,7 +26,7 @@ class GraphFusionPassManager: """ A pass manager for graph fusion passes. It handles the configuration and execution of passes. - The counterpart in vllm is PostGradPassManager. Since torch_npu + The counterpart in vllm is PostGradPassManager. Since torch_npu does not support triton for now, we define our own pass manager. """ @@ -48,13 +48,13 @@ class GraphFusionPassManager: def configure(self, config: VllmConfig): # By default, we enable the graph fusion and quantization fusion pass. - self.ascend_compilation_config: dict = config.additional_config.get( - "ascend_compilation_config", {}) + self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {}) if self.ascend_compilation_config.get("fuse_norm_quant", True): - from .passes.norm_quant_fusion_pass import \ - AddRMSNormQuantFusionPass + from .passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass + self.passes.append(AddRMSNormQuantFusionPass(config)) if self.ascend_compilation_config.get("fuse_qknorm_rope", True): from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass + self.passes.append(QKNormRopeFusionPass(config)) diff --git a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py index 0c12e68d..9da4548c 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py @@ -48,7 +48,8 @@ def _extra_stream_scope_check(match: Match) -> bool: logger.debug( f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " f"Multiple streams found: {non_default_streams}. " - f"Fusion is not supported for cross-stream operations.") + f"Fusion is not supported for cross-stream operations." + ) return False return True @@ -57,24 +58,29 @@ def _extra_stream_scope_check(match: Match) -> bool: @functools.lru_cache(None) # The replacement registered here will be actually executed after AOT. def replacement_add_rms_norm_quant(epsilon): - - def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor, + ): """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, epsilon) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon) out0 = output[0] out1 = output[2] - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, - torch.qint8, -1, False) + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False) return quantized_output, out1 - def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor): + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor, + ): """ Replacement for the AddRMSNormQuant fusion. """ @@ -82,10 +88,12 @@ def replacement_add_rms_norm_quant(epsilon): rms_norm_input, residual, rms_norm_weight, - # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. - 1. / scale, + # The inverse of scale is required by npu_add_rms_norm_quant kernel + # which is opposite to the npu_quantize kernel. + 1.0 / scale, offset, - epsilon=epsilon) + epsilon=epsilon, + ) quantized_output = output[0] out1 = output[2] return quantized_output, out1 @@ -103,33 +111,39 @@ def replacement_add_rms_norm_quant(epsilon): import torchair - torchair.register_replacement(search_fn=pattern, - replace_fn=replacement, - example_inputs=get_inputs(), - extra_check=_extra_stream_scope_check) + torchair.register_replacement( + search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check + ) # The replacement registered here will be actually executed after AOT. def replacement_add_rms_norm_quant_with_bias(epsilon): - - def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor, bias: torch.Tensor): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): """ Pattern for AddRMSNormQuantWithBias fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, epsilon) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon) out0 = output[0] out1 = output[2] out0 = out0 + bias - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, - torch.qint8, -1, False) + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False) return quantized_output, out1 - def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor, bias: torch.Tensor): + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): """ Replacement for AddRMSNormQuantWithBias fusion. """ @@ -137,11 +151,13 @@ def replacement_add_rms_norm_quant_with_bias(epsilon): rms_norm_input, residual, rms_norm_weight, - # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. - 1. / scale, + # The inverse of scale is required by npu_add_rms_norm_quant kernel + # which is opposite to the npu_quantize kernel. + 1.0 / scale, offset, epsilon=epsilon, - beta=bias) + beta=bias, + ) quantized_output = output[0] out1 = output[2] return quantized_output, out1 @@ -156,40 +172,41 @@ def replacement_add_rms_norm_quant_with_bias(epsilon): rmsnorm_bias = torch.randn(4, device="npu") scale = torch.ones(4, device="npu") offset = torch.zeros(4, device="npu") - return [ - rms_norm_input, residual, rms_norm_weight, scale, offset, - rmsnorm_bias - ] + return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias] import torchair - torchair.register_replacement(search_fn=pattern, - replace_fn=replacement, - example_inputs=get_inputs(), - extra_check=_extra_stream_scope_check) + torchair.register_replacement( + search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check + ) # The replacement registered here will be actually executed after AOT. def replacement_add_rms_norm_quant_sp_pattern(epsilon): - - def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor, + ): """ Pattern for AddRMSNormQuantSPPattern fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, epsilon) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon) out0 = output[0] out1 = output[2] out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, - torch.qint8, -1, False) + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False) return quantized_output, out1 - def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor): + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor, + ): """ Replacement for the AddRMSNormQuantSPPattern fusion. """ @@ -197,14 +214,15 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon): rms_norm_input, residual, rms_norm_weight, - # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. - 1. / scale, + # The inverse of scale is required by npu_add_rms_norm_quant kernel + # which is opposite to the npu_quantize kernel. + 1.0 / scale, offset, - epsilon=epsilon) + epsilon=epsilon, + ) quantized_output = output[0] out1 = output[2] - quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - quantized_output, True) + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True) return quantized_output, out1 def get_inputs(): @@ -220,34 +238,40 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon): import torchair - torchair.register_replacement(search_fn=pattern, - replace_fn=replacement, - example_inputs=get_inputs(), - extra_check=_extra_stream_scope_check) + torchair.register_replacement( + search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check + ) # The replacement registered here will be actually executed after AOT. def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon): - - def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor, bias: torch.Tensor): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): """ Pattern for AddRMSNormQuantSPPatternWithBias fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, epsilon) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon) out0 = output[0] out1 = output[2] out0 = out0 + bias out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, - torch.qint8, -1, False) + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False) return quantized_output, out1 - def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor, bias: torch.Tensor): + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): """ Replacement for the AddRMSNormQuantSPPatternWithBias fusion. """ @@ -255,15 +279,16 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon): rms_norm_input, residual, rms_norm_weight, - # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. - 1. / scale, + # The inverse of scale is required by npu_add_rms_norm_quant kernel + # which is opposite to the npu_quantize kernel. + 1.0 / scale, offset, epsilon=epsilon, - beta=bias) + beta=bias, + ) quantized_output = output[0] out1 = output[2] - quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - quantized_output, True) + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True) return quantized_output, out1 def get_inputs(): @@ -276,25 +301,19 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon): rmsnorm_bias = torch.randn(4, device="npu") scale = torch.ones(4, device="npu") offset = torch.zeros(4, device="npu") - return [ - rms_norm_input, residual, rms_norm_weight, scale, offset, - rmsnorm_bias - ] + return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias] import torchair - torchair.register_replacement(search_fn=pattern, - replace_fn=replacement, - example_inputs=get_inputs(), - extra_check=_extra_stream_scope_check) + torchair.register_replacement( + search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check + ) # register converter for pass common_epsilons = [1e-5, 1e-6] for eps in common_epsilons: - logger.info( - f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}" - ) + logger.info(f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}") replacement_add_rms_norm_quant(eps) replacement_add_rms_norm_quant_with_bias(eps) replacement_add_rms_norm_quant_sp_pattern(eps) diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index eeaccd80..e26b8429 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -25,7 +25,6 @@ from vllm.logger import logger class AddRMSNormQuantPattern: - def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): self.vllm_config = vllm_config self.dtype = vllm_config.model_config.dtype @@ -41,50 +40,48 @@ class AddRMSNormQuantPattern: scale = torch.ones(4, device="npu", dtype=self.dtype) scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) offset = torch.zeros(4, device="npu", dtype=self.dtype) - return [ - rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, - offset - ] + return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset] def register(self, pm_pass: PatternMatcherPass): - - def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - scale_reciprocal: torch.Tensor, offset: torch.Tensor): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + ): """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, self.eps) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) out0 = output[0] out1 = output[2] - quantized_output = torch.ops.vllm.quantize(out0, scale, - scale_reciprocal, - offset) + quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) return quantized_output, out1 - def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - scale_reciprocal: torch.Tensor, offset: torch.Tensor): + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + ): """ Replacement for the AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, - residual, - rms_norm_weight, - scale, - offset, - epsilon=self.eps) + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps + ) quantized_output = output[0] out1 = output[2] return quantized_output, out1 - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) class AddRMSNormQuantPatternWithBias: - def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): self.vllm_config = vllm_config self.dtype = vllm_config.model_config.dtype @@ -101,54 +98,51 @@ class AddRMSNormQuantPatternWithBias: scale = torch.ones(4, device="npu", dtype=self.dtype) scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) offset = torch.zeros(4, device="npu", dtype=self.dtype) - return [ - rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, - offset, rmsnorm_bias - ] + return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias] def register(self, pm_pass: PatternMatcherPass): - - def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - scale_reciprocal: torch.Tensor, offset: torch.Tensor, - bias: torch.Tensor): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, self.eps) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) out0 = output[0] out1 = output[2] out0 = out0 + bias - quantized_output = torch.ops.vllm.quantize(out0, scale, - scale_reciprocal, - offset) + quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) return quantized_output, out1 - def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - scale_reciprocal: torch.Tensor, offset: torch.Tensor, - bias: torch.Tensor): + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): """ Replacement for the AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, - residual, - rms_norm_weight, - scale, - offset, - epsilon=self.eps, - beta=bias) + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias + ) quantized_output = output[0] out1 = output[2] return quantized_output, out1 - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) class AddRMSNormQuantSPPattern: - def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): self.vllm_config = vllm_config self.dtype = vllm_config.model_config.dtype @@ -164,53 +158,50 @@ class AddRMSNormQuantSPPattern: scale = torch.ones(4, device="npu", dtype=self.dtype) scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) offset = torch.zeros(4, device="npu", dtype=self.dtype) - return [ - rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, - offset - ] + return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset] def register(self, pm_pass: PatternMatcherPass): - - def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - scale_reciprocal: torch.Tensor, offset: torch.Tensor): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + ): """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, self.eps) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) out0 = output[0] out1 = output[2] out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) - quantized_output = torch.ops.vllm.quantize(out0, scale, - scale_reciprocal, - offset) + quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) return quantized_output, out1 - def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - scale_reciprocal: torch.Tensor, offset: torch.Tensor): + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + ): """ Replacement for the AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, - residual, - rms_norm_weight, - scale, - offset, - epsilon=self.eps) + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps + ) quantized_output = output[0] out1 = output[2] - quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - quantized_output, True) + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True) return quantized_output, out1 - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) class AddRMSNormQuantSPPatternWithBias: - def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): self.vllm_config = vllm_config self.dtype = vllm_config.model_config.dtype @@ -227,53 +218,50 @@ class AddRMSNormQuantSPPatternWithBias: scale = torch.ones(4, device="npu", dtype=self.dtype) scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) offset = torch.zeros(4, device="npu", dtype=self.dtype) - return [ - rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, - offset, rmsnorm_bias - ] + return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias] def register(self, pm_pass: PatternMatcherPass): - - def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - scale_reciprocal: torch.Tensor, offset: torch.Tensor, - bias: torch.Tensor): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, self.eps) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) out0 = output[0] out1 = output[2] out0 = out0 + bias out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) - quantized_output = torch.ops.vllm.quantize(out0, scale, - scale_reciprocal, - offset) + quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) return quantized_output, out1 - def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, - rms_norm_weight: torch.Tensor, scale: torch.Tensor, - scale_reciprocal: torch.Tensor, offset: torch.Tensor, - bias: torch.Tensor): + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): """ Replacement for the AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, - residual, - rms_norm_weight, - scale, - offset, - epsilon=self.eps, - beta=bias) + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias + ) quantized_output = output[0] out1 = output[2] - quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - quantized_output, True) + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True) return quantized_output, out1 - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) class AddRMSNormQuantFusionPass(VllmInductorPass): @@ -283,25 +271,19 @@ class AddRMSNormQuantFusionPass(VllmInductorPass): def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) - self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( - pass_name="rmsnorm_quant_fusion_pass") + self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass") dtype = vllm_config.model_config.dtype if dtype not in (torch.bfloat16, torch.float16): - logger.debug("Quant fusion not enabled: unsupported dtype %s", - dtype) + logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype) return common_epsilons = [1e-5, 1e-6] for eps in common_epsilons: - AddRMSNormQuantPattern(vllm_config, - eps=eps).register(self.pattern_match_passes) - AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register( - self.pattern_match_passes) - AddRMSNormQuantSPPattern(vllm_config, eps=eps).register( - self.pattern_match_passes) - AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register( - self.pattern_match_passes) + AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes) + AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes) + AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes) + AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin() diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index ed90c7f8..f9dbf768 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -17,8 +17,7 @@ # import torch import torch._inductor.pattern_matcher as pm -from torch._inductor.pattern_matcher import (PatternMatcherPass, - PatternPrettyPrinter) +from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter from vllm.attention.layer import Attention from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_layers_from_vllm_config @@ -27,13 +26,7 @@ from vllm.logger import logger class QKNormRopeFusionPattern: - - def __init__(self, - vllm_config, - head_dim, - num_heads, - num_kv_heads, - eps=1e-6): + def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6): self.vllm_config = vllm_config self.head_dim = head_dim self.num_heads = num_heads @@ -45,65 +38,38 @@ class QKNormRopeFusionPattern: def get_inputs(self): T = 5 - qkv = torch.empty(T, - self.q_size + 2 * self.kv_size, - dtype=torch.bfloat16, - device="npu") - q_weight = torch.empty(self.head_dim, - dtype=torch.bfloat16, - device="npu") - k_weight = torch.empty(self.head_dim, - dtype=torch.bfloat16, - device="npu") - cos = torch.empty(1, - T, - 1, - self.head_dim, - dtype=torch.bfloat16, - device="npu") - sin = torch.empty(1, - T, - 1, - self.head_dim, - dtype=torch.bfloat16, - device="npu") + qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") + q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") + sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") return [qkv, q_weight, k_weight, cos, sin] def register(self, pm_pass: PatternMatcherPass): + def pattern( + qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor): + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) - - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, - self.head_dim) - q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, - self.eps) - - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, - self.head_dim) - k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, - self.eps) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) q_flat = q_norm_out.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, - self.head_dim) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) k_flat = k_norm_out.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, - self.head_dim) + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( - q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) return q_rope, k_rope, v - def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor): + def replacement( + qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, q_weight=q_weight, @@ -115,22 +81,16 @@ class QKNormRopeFusionPattern: q_bias=None, k_bias=None, sin=sin, - cos=cos) + cos=cos, + ) return results - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) class QKNormRopeFusionPatternWithBias: - - def __init__(self, - vllm_config, - head_dim, - num_heads, - num_kv_heads, - eps=1e-6): + def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6): self.head_dim = head_dim self.num_heads = num_heads self.num_kv_heads = num_kv_heads @@ -142,71 +102,55 @@ class QKNormRopeFusionPatternWithBias: def get_inputs(self): T = 5 - qkv = torch.empty(T, - self.q_size + 2 * self.kv_size, - dtype=torch.bfloat16, - device="npu") - q_weight = torch.empty(self.head_dim, - dtype=torch.bfloat16, - device="npu") - k_weight = torch.empty(self.head_dim, - dtype=torch.bfloat16, - device="npu") + qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") + q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, - T, - 1, - self.head_dim, - dtype=torch.bfloat16, - device="npu") - sin = torch.empty(1, - T, - 1, - self.head_dim, - dtype=torch.bfloat16, - device="npu") + cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") + sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] def register(self, pm_pass: PatternMatcherPass): + def pattern( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_bias: torch.Tensor, + k_bias: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, q_bias: torch.Tensor, - k_bias: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor): - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) - - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, - self.head_dim) - q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, - self.eps) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps) q_normed = q_norm_out + q_bias - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, - self.head_dim) - k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, - self.eps) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) k_normed = k_norm_out + k_bias q_flat = q_normed.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, - self.head_dim) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) k_flat = k_normed.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, - self.head_dim) + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( - q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) return q_rope, k_rope, v - def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, q_bias: torch.Tensor, - k_bias: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor): + def replacement( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_bias: torch.Tensor, + k_bias: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, q_weight=q_weight, @@ -218,11 +162,11 @@ class QKNormRopeFusionPatternWithBias: q_bias=q_bias, k_bias=k_bias, cos=cos, - sin=sin) + sin=sin, + ) return results - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) class QKNormRopeFusionPass(VllmInductorPass): @@ -232,44 +176,38 @@ class QKNormRopeFusionPass(VllmInductorPass): def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) - self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( - pass_name="qknorm_rope_fusion_pass") + self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass") dtype = vllm_config.model_config.dtype if dtype not in (torch.bfloat16, torch.float16): - logger.debug( - "QKNorm and Rope fusion not enabled: unsupported dtype %s", - dtype) + logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype) return # use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern - attn_layers: dict[str, Attention] = get_layers_from_vllm_config( - vllm_config, Attention) + attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention) if len(attn_layers) == 0: - logger.debug( - "QKNorm and Rope fusion enabled, but no Attention layers were discovered." - ) + logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.") return layer = next(iter(attn_layers.values())) for epsilon in [1e-6, 1e-5]: if layer.head_size != 128: - logger.debug( - "QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", - layer.head_size) + logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size) continue - QKNormRopeFusionPattern(vllm_config=vllm_config, - head_dim=layer.head_size, - num_heads=layer.num_heads, - num_kv_heads=layer.num_kv_heads, - eps=epsilon).register( - self.pattern_match_passes) + QKNormRopeFusionPattern( + vllm_config=vllm_config, + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + ).register(self.pattern_match_passes) - QKNormRopeFusionPatternWithBias(vllm_config=vllm_config, - head_dim=layer.head_size, - num_heads=layer.num_heads, - num_kv_heads=layer.num_kv_heads, - eps=epsilon).register( - self.pattern_match_passes) + QKNormRopeFusionPatternWithBias( + vllm_config=vllm_config, + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + ).register(self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin() diff --git a/vllm_ascend/cpu_binding.py b/vllm_ascend/cpu_binding.py index 9bcd5319..062ff4f9 100644 --- a/vllm_ascend/cpu_binding.py +++ b/vllm_ascend/cpu_binding.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- import os import subprocess from collections import defaultdict -from typing import Dict, List, Tuple import psutil from vllm.logger import logger @@ -13,26 +11,22 @@ ALLOWED_CPUS_PATH = "/proc/self/status" ASCEND_RT_VISIBLE_DEVICES = os.getenv("ASCEND_RT_VISIBLE_DEVICES") -def execute_command(cmd: List[str]) -> Tuple[str, int]: - with subprocess.Popen(cmd, - shell=False, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) as p: +def execute_command(cmd: list[str]) -> tuple[str, int]: + with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p: out, _ = p.communicate(timeout=1000) return out.decode(), p.returncode class DeviceInfo: - def __init__(self): - self.npu_map_info: Dict[str, Dict[str, str]] = self.get_npu_map_info() - self.allowed_cpus: List[int] = self.parse_allowed_cpus() - self.running_npu_list: List[int] = self.get_running_npus() - self.npu_affinity: Dict[int, List[int]] = self.parse_topo_affinity() + self.npu_map_info: dict[str, dict[str, str]] = self.get_npu_map_info() + self.allowed_cpus: list[int] = self.parse_allowed_cpus() + self.running_npu_list: list[int] = self.get_running_npus() + self.npu_affinity: dict[int, list[int]] = self.parse_topo_affinity() @staticmethod - def expand_cpu_list(allowed_list_str: str) -> List[int]: - allowed_cpus_list: List[int] = [] + def expand_cpu_list(allowed_list_str: str) -> list[int]: + allowed_cpus_list: list[int] = [] for per_range in allowed_list_str.split(","): if "-" in per_range: start_cpu, end_cpu = map(int, per_range.split("-")) @@ -42,8 +36,8 @@ class DeviceInfo: return allowed_cpus_list @staticmethod - def get_npu_map_info() -> Dict[str, Dict[str, str]]: - npu_map_info: Dict[str, Dict[str, str]] = {} + def get_npu_map_info() -> dict[str, dict[str, str]]: + npu_map_info: dict[str, dict[str, str]] = {} npu_info, _ = execute_command(["npu-smi", "info", "-m"]) npu_map = npu_info.strip().split("\n")[1:] for line in npu_map: @@ -55,7 +49,7 @@ class DeviceInfo: npu_map_info[npu_id][chip_id] = chip_logic_id return npu_map_info - def get_running_npus(self) -> List[int]: + def get_running_npus(self) -> list[int]: npu_message, _ = execute_command(["npu-smi", "info"]) in_proc_section = False running_npu_set = set() @@ -76,36 +70,29 @@ class DeviceInfo: continue chip_logic_id = self.npu_map_info.get(npu_id, {}).get(chip_id) if not chip_logic_id or not chip_logic_id.isdigit(): - raise RuntimeError( - "Failed to get correct chip_logic_id from command 'npu-smi info -m'." - ) + raise RuntimeError("Failed to get correct chip_logic_id from command 'npu-smi info -m'.") running_npu_set.add(int(chip_logic_id)) if ASCEND_RT_VISIBLE_DEVICES: devices_str = ASCEND_RT_VISIBLE_DEVICES devices_list = [int(x) for x in devices_str.split(",")] running_npu_set = set(devices_list) & running_npu_set if not running_npu_set: - raise RuntimeError( - "Can not get running npu info, you can use BIND_CPU=0 to skip." - ) + raise RuntimeError("Can not get running npu info, you can use BIND_CPU=0 to skip.") return sorted(running_npu_set) - def parse_allowed_cpus(self) -> List[int]: + def parse_allowed_cpus(self) -> list[int]: if not os.path.exists(ALLOWED_CPUS_PATH): return [] with open(ALLOWED_CPUS_PATH) as f: for line in f: if line.startswith("Cpus_allowed_list"): return self.expand_cpu_list(line.split()[1]) - raise RuntimeError( - "Can not found specific 'Cpus_allowed_list' in the '/proc/self/status' file." - ) + raise RuntimeError("Can not found specific 'Cpus_allowed_list' in the '/proc/self/status' file.") - def parse_topo_affinity(self) -> Dict[int, List[int]]: + def parse_topo_affinity(self) -> dict[int, list[int]]: chip_logic_id = 0 - affinity: Dict[int, List[int]] = {} - affinity_message, _ = execute_command( - ["npu-smi", "info", "-t", "topo"]) + affinity: dict[int, list[int]] = {} + affinity_message, _ = execute_command(["npu-smi", "info", "-t", "topo"]) for line in affinity_message.splitlines(): if line.startswith("NPU"): parts = line.split() @@ -117,21 +104,19 @@ class DeviceInfo: class CpuAlloc: - def __init__(self, rank_id: int): self.rank_id = rank_id self.device_info: DeviceInfo = DeviceInfo() - self.cpu_node: Dict[int, int] = {} - self.numa_to_cpu_map: Dict[int, List[int]] = defaultdict(list) - self.npu_cpu_pool: Dict[int, List[int]] = {} - self.assign_main: Dict[int, List[int]] = {} - self.assign_acl: Dict[int, List[int]] = {} - self.assign_rel: Dict[int, List[int]] = {} + self.cpu_node: dict[int, int] = {} + self.numa_to_cpu_map: dict[int, list[int]] = defaultdict(list) + self.npu_cpu_pool: dict[int, list[int]] = {} + self.assign_main: dict[int, list[int]] = {} + self.assign_acl: dict[int, list[int]] = {} + self.assign_rel: dict[int, list[int]] = {} @staticmethod - def get_threads_map( - thread_message: str) -> Dict[str, Dict[str, List[str]]]: - threads_map: Dict[str, Dict[str, List[str]]] = {} + def get_threads_map(thread_message: str) -> dict[str, dict[str, list[str]]]: + threads_map: dict[str, dict[str, list[str]]] = {} for line in thread_message.splitlines(): parts = line.split() if len(parts) < 2: @@ -144,40 +129,33 @@ class CpuAlloc: else: continue if main_pid not in threads_map: - threads_map[main_pid] = { - "acl_thread": [], - "release_thread": [] - } + threads_map[main_pid] = {"acl_thread": [], "release_thread": []} threads_map[main_pid][key].append(sub_pid) return threads_map @staticmethod - def bind(pid: str, cpus: List[int], bind_sub_thread: bool) -> None: + def bind(pid: str, cpus: list[int], bind_sub_thread: bool) -> None: if cpus: cpu_list = ",".join(map(str, cpus)) if bind_sub_thread: - bind_result, return_code = execute_command( - ["taskset", "-acp", cpu_list, pid]) + bind_result, return_code = execute_command(["taskset", "-acp", cpu_list, pid]) else: - bind_result, return_code = execute_command( - ["taskset", "-cp", cpu_list, pid]) + bind_result, return_code = execute_command(["taskset", "-cp", cpu_list, pid]) if return_code != 0: raise RuntimeError(f"Failed to bind {pid} to CPU {cpu_list}.") - def average_distribute( - self, groups: Dict[str, List[int]]) -> Dict[int, List[int]]: - result: Dict[int, List[int]] = {} + def average_distribute(self, groups: dict[str, list[int]]) -> dict[int, list[int]]: + result: dict[int, list[int]] = {} for key, npu_list in groups.items(): cpu_list = sorted(self.npu_cpu_pool[npu_list[0]]) cpu_num_per_npu = len(cpu_list) // len(npu_list) for i, npu in enumerate(npu_list): start_index = i * cpu_num_per_npu - end_index = (i + 1) * cpu_num_per_npu if i < len( - npu_list) - 1 else len(cpu_list) + end_index = (i + 1) * cpu_num_per_npu if i < len(npu_list) - 1 else len(cpu_list) result[npu] = cpu_list[start_index:end_index] return result - def extend_numa(self, cpu_list: List[int]) -> List[int]: + def extend_numa(self, cpu_list: list[int]) -> list[int]: if not cpu_list: return [] nodes = {self.cpu_node[c] for c in cpu_list} @@ -203,9 +181,7 @@ class CpuAlloc: self.cpu_node[cpu] = node self.numa_to_cpu_map[node].append(cpu) if len(self.numa_to_cpu_map) == 0: - raise RuntimeError( - "lscpu command output error, no NUMA node available. Please check!" - ) + raise RuntimeError("lscpu command output error, no NUMA node available. Please check!") def handle_no_affinity(self) -> None: num_running_npu = len(self.device_info.running_npu_list) @@ -219,10 +195,7 @@ class CpuAlloc: index = 0 for node in sorted(self.numa_to_cpu_map): # Available CPUs on this NUMA (constrained by allowed_cpus) - cpus = [ - c for c in self.numa_to_cpu_map[node] - if c in self.device_info.allowed_cpus - ] + cpus = [c for c in self.numa_to_cpu_map[node] if c in self.device_info.allowed_cpus] if not cpus: continue # The actual number of NPUs to be allocated on this NUMA. @@ -251,19 +224,16 @@ class CpuAlloc: return for npu in self.device_info.running_npu_list: base_cpu_list = [ - cpu for cpu in self.device_info.npu_affinity.get(npu, []) - if cpu in self.device_info.allowed_cpus + cpu for cpu in self.device_info.npu_affinity.get(npu, []) if cpu in self.device_info.allowed_cpus ] if not base_cpu_list: - raise RuntimeError( - "CPUs available in 'Cpus_allowed_list' conflict with NUMA affinity." - ) + raise RuntimeError("CPUs available in 'Cpus_allowed_list' conflict with NUMA affinity.") extra_cpu_list = self.extend_numa(base_cpu_list) self.npu_cpu_pool[npu] = extra_cpu_list groups = defaultdict(list) for npu, cpus in self.npu_cpu_pool.items(): groups[str(cpus)].append(npu) - final: Dict[int, List[int]] = {} + final: dict[int, list[int]] = {} for key, npu_list in groups.items(): if len(npu_list) == 1: final[npu_list[0]] = self.npu_cpu_pool[npu_list[0]] @@ -279,8 +249,8 @@ class CpuAlloc: rel = [pool[-1]] else: raise RuntimeError( - "The number of CPUs is insufficient to bind to the NPUs. " - "Each NPU requires at least 3 CPUs.") + "The number of CPUs is insufficient to bind to the NPUs. Each NPU requires at least 3 CPUs." + ) self.assign_main[npu] = main self.assign_acl[npu] = acl self.assign_rel[npu] = rel @@ -290,10 +260,8 @@ class CpuAlloc: current_npu = self.device_info.running_npu_list[self.rank_id] main = " ".join(map(str, self.assign_main[current_npu])) acl = " ".join(map(str, self.assign_acl[current_npu])) - rel = str(self.assign_rel[current_npu] - ) if self.assign_rel[current_npu] else "" - logger.info( - f"NPU{current_npu}: main=[{main}] acl=[{acl}] release=[{rel}]") + rel = str(self.assign_rel[current_npu]) if self.assign_rel[current_npu] else "" + logger.info(f"NPU{current_npu}: main=[{main}] acl=[{acl}] release=[{rel}]") def bind_threads(self) -> None: thread_message, _ = execute_command(["ps", "-Te"]) @@ -303,8 +271,7 @@ class CpuAlloc: self.bind(main_pid, self.assign_main[current_npu], True) for acl_thread in threads_map.get(main_pid, {}).get("acl_thread", []): self.bind(acl_thread, self.assign_acl[current_npu], False) - for release_thread in threads_map.get(main_pid, - {}).get("release_thread", []): + for release_thread in threads_map.get(main_pid, {}).get("release_thread", []): self.bind(release_thread, self.assign_rel[current_npu], False) def run_all(self) -> None: diff --git a/vllm_ascend/flash_common3_context.py b/vllm_ascend/flash_common3_context.py index a579af90..3f6cf268 100644 --- a/vllm_ascend/flash_common3_context.py +++ b/vllm_ascend/flash_common3_context.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import torch from vllm.model_executor.layers.linear import LinearBase @@ -7,26 +6,26 @@ from vllm.model_executor.layers.linear import LinearBase @dataclass class FlashCommon3Context: - gate: Optional[LinearBase] = None - topk_weights: Optional[torch.Tensor] = None - topk_ids: Optional[torch.Tensor] = None - row_idx: Optional[torch.Tensor] = None - shared_experts: Optional[torch.nn.Module] = None - shared_out: Optional[torch.Tensor] = None + gate: LinearBase | None = None + topk_weights: torch.Tensor | None = None + topk_ids: torch.Tensor | None = None + row_idx: torch.Tensor | None = None + shared_experts: torch.nn.Module | None = None + shared_out: torch.Tensor | None = None -_flash_common3_context: Optional[FlashCommon3Context] = None +_flash_common3_context: FlashCommon3Context | None = None -def get_flash_common3_context() -> Optional[FlashCommon3Context]: +def get_flash_common3_context() -> FlashCommon3Context | None: return _flash_common3_context def set_flash_common3_context( - topk_weights: Optional[torch.Tensor] = None, - topk_ids: Optional[torch.Tensor] = None, - shared_experts: Optional[torch.nn.Module] = None, - shared_out: Optional[torch.Tensor] = None, + topk_weights: torch.Tensor | None = None, + topk_ids: torch.Tensor | None = None, + shared_experts: torch.nn.Module | None = None, + shared_out: torch.Tensor | None = None, ): global _flash_common3_context if _flash_common3_context is None: diff --git a/vllm_ascend/meta_registration.py b/vllm_ascend/meta_registration.py index 9a58afd9..cc154798 100644 --- a/vllm_ascend/meta_registration.py +++ b/vllm_ascend/meta_registration.py @@ -46,17 +46,20 @@ def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""): if overload != "": op_name = op_name + "." + overload schema_to_find = ns + "::" + op_name - meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key( - "Meta") + meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key("Meta") if schema_to_find in meta_impl_list: return lib.impl(op_name, fn, "Meta") -def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool): - +def rotary_embedding_meta( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +): num_tokens = positions.numel() query_hidden_size = query.numel() // num_tokens key_hidden_size = key.numel() // num_tokens @@ -68,38 +71,41 @@ def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor, return query_dst, key_dst -def get_masked_input_and_mask_meta(input: torch.Tensor, - org_vocab_start_index: int, - org_vocab_end_index: int, - num_org_vocab_padding: int, - added_vocab_start_index: int, - added_vocab_end_index: int): - +def get_masked_input_and_mask_meta( + input: torch.Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int, +): masked_input = torch.empty_like(input) mask = torch.empty_like(input).to(torch.bool) return masked_input, mask -def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, - indices: torch.Tensor, y: torch.Tensor, slice_offset: int, - slice_size: int): - +def bgmv_expand_meta( + x: torch.Tensor, weight: torch.Tensor, indices: torch.Tensor, y: torch.Tensor, slice_offset: int, slice_size: int +): y_out = torch.empty_like(y) return y_out -def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, - lora_indices: torch.Tensor, seq_len: torch.Tensor, - y: torch.Tensor, slice_offset: int, slice_size: int): - +def sgmv_expand_meta( + x: torch.Tensor, + weight: torch.Tensor, + lora_indices: torch.Tensor, + seq_len: torch.Tensor, + y: torch.Tensor, + slice_offset: int, + slice_size: int, +): y_out = torch.empty_like(y) return y_out -register_meta_if_necessary("_C_ascend", "rotary_embedding", - rotary_embedding_meta) -register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", - get_masked_input_and_mask_meta) +register_meta_if_necessary("_C_ascend", "rotary_embedding", rotary_embedding_meta) +register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta) register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta) register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index eee64fce..3ca49187 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -15,9 +15,11 @@ # This file is a part of the vllm-ascend project. # +from __future__ import annotations + import math import os -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from uuid import uuid4 import torch @@ -32,11 +34,21 @@ from vllm_ascend.ascend_config import init_ascend_config # isort: off from vllm_ascend.utils import ( - ASCEND_QUANTIZATION_METHOD, COMPILATION_PASS_KEY, - COMPRESSED_TENSORS_METHOD, AscendDeviceType, check_kv_extra_config, - enable_sp, flashcomm2_enable, get_ascend_device_type, is_moe_model, - is_vl_model, refresh_block_size, update_aclgraph_sizes, - update_cudagraph_capture_sizes, update_default_aclgraph_sizes) + ASCEND_QUANTIZATION_METHOD, + COMPILATION_PASS_KEY, + COMPRESSED_TENSORS_METHOD, + AscendDeviceType, + check_kv_extra_config, + enable_sp, + flashcomm2_enable, + get_ascend_device_type, + is_moe_model, + is_vl_model, + refresh_block_size, + update_aclgraph_sizes, + update_cudagraph_capture_sizes, + update_default_aclgraph_sizes, +) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -80,7 +92,6 @@ def config_deprecated_logging(): class NPUPlatform(Platform): - _enum = PlatformEnum.OOT device_name: str = "npu" device_type: str = "npu" @@ -89,9 +100,7 @@ class NPUPlatform(Platform): device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" dispatch_key: str = "PrivateUse1" - supported_quantization: list[str] = [ - ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD - ] + supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD] def is_sleep_mode_available(self) -> bool: return True @@ -116,33 +125,29 @@ class NPUPlatform(Platform): @classmethod def get_compile_backend(self) -> str: """ - Get the custom compile backend. Previously, we used EagerAdaptor by default. + Get the custom compile backend. Previously, we used EagerAdaptor by default. To use graph fusion operations, we defined our own backend compiler. """ return "vllm_ascend.compilation.compiler_interface.AscendCompiler" @classmethod - def pre_register_and_update(cls, - parser: Optional[FlexibleArgumentParser] = None - ) -> None: + def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) -> None: # Adapt the global patch here. from vllm_ascend.utils import adapt_patch + adapt_patch(is_global_patch=True) # For online serving, "ascend" quantization method is not a choice natively, # so we need to add "ascend" quantization method to quantization methods list # and the user can enable quantization using "vllm serve --quantization ascend". if parser is not None: - quant_action = parser._option_string_actions.get('--quantization') - if quant_action and hasattr(quant_action, - 'choices') and quant_action.choices: + quant_action = parser._option_string_actions.get("--quantization") + if quant_action and hasattr(quant_action, "choices") and quant_action.choices: if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) - from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \ - AscendCompressedTensorsConfig # noqa: F401 - from vllm_ascend.quantization.quant_config import \ - AscendQuantConfig # noqa: F401 + from vllm_ascend.quantization.compressed_tensors.compressed_tensors import AscendCompressedTensorsConfig # noqa: F401 + from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401 config_deprecated_logging() @@ -169,8 +174,7 @@ class NPUPlatform(Platform): if vllm_config.kv_transfer_config is not None: check_kv_extra_config(vllm_config) - if not getattr(vllm_config.kv_transfer_config, - "_engine_id_patched", False): + if not getattr(vllm_config.kv_transfer_config, "_engine_id_patched", False): vllm_config.kv_transfer_config.engine_id = f"{vllm_config.kv_transfer_config.engine_id}-{uuid4().hex}" vllm_config.kv_transfer_config._engine_id_patched = True from vllm.config import CompilationMode # noqa: E402 @@ -181,24 +185,22 @@ class NPUPlatform(Platform): cache_config = vllm_config.cache_config ascend_compilation_config = ascend_config.ascend_compilation_config if ascend_compilation_config: - vllm_config.additional_config.setdefault( - "ascend_compilation_config", {}).update( - vars(ascend_compilation_config - ) if not isinstance(ascend_compilation_config, dict) - else ascend_compilation_config) + vllm_config.additional_config.setdefault("ascend_compilation_config", {}).update( + vars(ascend_compilation_config) + if not isinstance(ascend_compilation_config, dict) + else ascend_compilation_config + ) - elif model_config and hasattr(model_config.hf_text_config, - "index_topk"): - vllm_config.cache_config.cache_dtype = str( - model_config.dtype).replace("torch.", "") + elif model_config and hasattr(model_config.hf_text_config, "index_topk"): + vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "") if model_config is None: - logger.warning("Model config is missing. This may indicate " - "that we are running a test case") + logger.warning("Model config is missing. This may indicate that we are running a test case") enforce_eager = False else: enforce_eager = getattr(model_config, "enforce_eager", False) from vllm.config.compilation import CUDAGraphMode + if enforce_eager: logger.info("Compilation disabled, using eager mode by default") compilation_config.mode = CompilationMode.NONE @@ -207,12 +209,10 @@ class NPUPlatform(Platform): compilation_config.cudagraph_num_of_warmups = 1 - if compilation_config.mode not in [ - CompilationMode.NONE, CompilationMode.VLLM_COMPILE - ]: + if compilation_config.mode not in [CompilationMode.NONE, CompilationMode.VLLM_COMPILE]: logger.warning( - "NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", - compilation_config.mode) + "NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", compilation_config.mode + ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE # set cudaprah sizes before extending `compilation_config.splitting_ops` @@ -223,15 +223,18 @@ class NPUPlatform(Platform): update_default_aclgraph_sizes(vllm_config) # TODO delete graph size update here when compilation_config.pass_config.enable_sp # is supported by vllm-ascend. - if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \ - enable_sp(vllm_config): + if ( + vllm_config.parallel_config.tensor_parallel_size > 1 + and not vllm_config.model_config.enforce_eager + and enable_sp(vllm_config) + ): original_sizes = compilation_config.cudagraph_capture_sizes - sp_aclgraph_sizes = \ - vllm_config.update_sizes_for_sequence_parallelism(original_sizes) + sp_aclgraph_sizes = vllm_config.update_sizes_for_sequence_parallelism(original_sizes) assert sp_aclgraph_sizes, ( f"cudagraph_capture_sizes {original_sizes} does not contain" f"values that are multiples of tp_size " - f"{vllm_config.parallel_config.tensor_parallel_size}") + f"{vllm_config.parallel_config.tensor_parallel_size}" + ) if len(sp_aclgraph_sizes) != len(original_sizes): compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes update_cudagraph_capture_sizes(vllm_config, sp_aclgraph_sizes) @@ -243,9 +246,7 @@ class NPUPlatform(Platform): # encoder-decoder models currently only support piecewise mode if model_config and model_config.is_encoder_decoder is True: if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: - logger.warning( - "encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE " - ) + logger.warning("encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE ") compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE # get custom compile backend for graph fusion @@ -255,15 +256,14 @@ class NPUPlatform(Platform): compilation_config.mode = CompilationMode.NONE ascend_config.enable_npugraph_ex = False elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: - logger.info( - "PIECEWISE compilation enabled on NPU. use_inductor not supported - " - "using only ACL Graph mode") - assert compilation_config.mode == CompilationMode.VLLM_COMPILE, \ - "When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE" + logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode") + assert compilation_config.mode == CompilationMode.VLLM_COMPILE, ( + "When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == " + "CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE" + ) compilation_config.set_splitting_ops_for_v1( all2all_backend=vllm_config.parallel_config.all2all_backend, - data_parallel_size=vllm_config.parallel_config. - data_parallel_size, + data_parallel_size=vllm_config.parallel_config.data_parallel_size, ) compilation_config.use_inductor = False # NOTE: Theoretically, we should also add vllm::mla_forward in the attention ops. @@ -275,11 +275,13 @@ class NPUPlatform(Platform): compilation_config.splitting_ops.extend(["vllm::mla_forward"]) update_aclgraph_sizes(vllm_config) ascend_config.enable_npugraph_ex = False - elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ - compilation_config.cudagraph_mode == CUDAGraphMode.FULL: + elif ( + compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY + or compilation_config.cudagraph_mode == CUDAGraphMode.FULL + ): logger.info( - "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " - "using only ACL Graph mode") + "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode" + ) compilation_config.use_inductor = False compilation_config.splitting_ops = [] warning_message = """\033[91m @@ -297,30 +299,31 @@ class NPUPlatform(Platform): logger.warning(warning_message) else: logger.info( - "%s cudagraph_mode is not support on NPU. falling back to NONE", - compilation_config.cudagraph_mode) + "%s cudagraph_mode is not support on NPU. falling back to NONE", compilation_config.cudagraph_mode + ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE compilation_config.mode = CompilationMode.NONE ascend_config.enable_npugraph_ex = False # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1 # Then, we will have to discuss the error handling strategy and user experience - if compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \ - os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1": + if ( + compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1" + ): raise ValueError( "ACL graph is incompatible with ASCEND_LAUNCH_BLOCKING=1. " "Please unset ASCEND_LAUNCH_BLOCKING or set it to 0. If you " "need ASCEND_LAUNCH_BLOCKING for debugging, consider other methods — " "for example, check the plog files (default: $HOME/ascend/log/debug) " - "for more information about runtime errors.") + "for more information about runtime errors." + ) if parallel_config and parallel_config.worker_cls == "auto": # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. parallel_config.all2all_backend = "flashinfer_all2allv" if ascend_config.xlite_graph_config.enabled: - logger.info( - "openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite" - ) + logger.info("openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite") parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker" else: parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" @@ -332,10 +335,9 @@ class NPUPlatform(Platform): compilation_config.custom_ops = ["all"] if ascend_config.recompute_scheduler_enable: - from vllm_ascend.core.recompute_scheduler import \ - RecomputeSchedulerConfig - recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config( - vllm_config) + from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig + + recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(vllm_config) vllm_config.scheduler_config = recompute_scheduler_config # Extend original scheduler_config to use SchedulerDynamicBatch. @@ -346,9 +348,11 @@ class NPUPlatform(Platform): vllm_config.scheduler_config.enable_chunked_prefill = True vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch - if vllm_config.kv_transfer_config is not None and \ - cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \ - parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1: + if ( + vllm_config.kv_transfer_config is not None + and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size + and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1 + ): raise AssertionError( f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) " f"and block_size({cache_config.block_size}) " @@ -356,12 +360,14 @@ class NPUPlatform(Platform): ) if is_vl_model(vllm_config): - if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))) or \ - bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))): + if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool( + int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0")) + ): raise ValueError( "Currently, VL models doesn't support " "FLASHCOMM in vllm-ascend. We will fix this in the future. " - "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0.") + "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0." + ) @classmethod def import_kernels(cls) -> None: @@ -377,14 +383,11 @@ class NPUPlatform(Platform): if _CUSTOM_OP_REGISTERED: return CUR_DIR = os.path.dirname(os.path.realpath(__file__)) - CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors", - "vllm-ascend") + CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors", "vllm-ascend") if os.path.exists(CUSTOM_OPP_PATH): - current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", - "") + current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "") if current_cust_opp_path: - os.environ[ - "ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}" + os.environ["ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}" else: os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH _CUSTOM_OP_REGISTERED = True @@ -393,22 +396,18 @@ class NPUPlatform(Platform): def get_attn_backend_cls(cls, selected_backend, attn_selector_config): backend_map = { (True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend", - (False, False): - "vllm_ascend.attention.attention_v1.AscendAttentionBackend", + (False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend", (True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend", } - return backend_map[(attn_selector_config.use_mla, - attn_selector_config.use_sparse)] + return backend_map[(attn_selector_config.use_mla, attn_selector_config.use_sparse)] @classmethod def get_punica_wrapper(cls) -> str: return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU" @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float: torch.npu.reset_peak_memory_stats(device) return torch.npu.max_memory_allocated(device) @@ -457,32 +456,33 @@ class NPUPlatform(Platform): Args: attn_metadata (dict[str, Any]): attention metadata for all layers. vllm_config (VllmConfig): configuration of vllm. - dp_metadata (DpMetada): metadata for data parallelism. + dp_metadata (DpMetada): metadata for data parallelism. lack of typehint because of circular import. virtual_engine (int, optional): index of virtual engine. Defaults to 0. num_tokens (int | None, optional): number of tokens. Defaults to None. - num_tokens_across_dp (torch.Tensor | None, optional): number of tokens + num_tokens_across_dp (torch.Tensor | None, optional): number of tokens across data parallelism.Defaults to None. cudagraph_runtime_mode (CUDAGraphMode, optional): mode of cudagraph runtime. Defaults to None.lack of typehint because of circular import. batch_descriptor (BatchDescriptor, optional): descriptor of batch. Defaults to None. - ubatch_slices (UBatchSlices, optional): slice info for dual batch. + ubatch_slices (UBatchSlices, optional): slice info for dual batch. Defaults to None. lack of typehint because of circular import Returns: dict[str, Any]: _description_ """ # NOTE(Ronald1995): avoid circular import. - from vllm_ascend.ascend_forward_context import (get_mc2_mask, - select_moe_comm_method) + from vllm_ascend.ascend_forward_context import get_mc2_mask, select_moe_comm_method from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size + # NOTE(Ronald1995): avoid circular import, cudagraph_runtime_mode is # CUDAGraphMode.NONE in vllm, but we can't set CUDAGraphMode.NONE in # argument default value, so we set it to None first, then set it to # CUDAGraphMode.NONE here. from vllm.config import CUDAGraphMode + if cudagraph_runtime_mode is None: cudagraph_runtime_mode = CUDAGraphMode.NONE # TODO(Ronald1995): model runner v1 still use ascend_forward_context, @@ -526,31 +526,26 @@ class NPUPlatform(Platform): sp_enabled = enable_sp(vllm_config) and num_tokens is not None mmrs_fusion = False else: - sp_enabled = enable_sp(vllm_config) and \ - num_tokens is not None and num_tokens > 1000 + sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000 # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 - flashcomm_v2_enabled = flashcomm2_enable( - ) and tp_world_size > 1 and num_tokens is not None + flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None pad_size = 0 - if (sp_enabled or flashcomm_v2_enabled): - pad_size = (tp_world_size - - (num_tokens % tp_world_size)) % tp_world_size + if sp_enabled or flashcomm_v2_enabled: + pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size dp_world_size = get_dp_group().world_size if dp_world_size > 1 and dp_metadata is not None: max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item() - if (sp_enabled or flashcomm_v2_enabled): - padded_length = (max_tokens_across_dp + tp_world_size - - 1) // tp_world_size * tp_world_size + if sp_enabled or flashcomm_v2_enabled: + padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size pad_size = padded_length - num_tokens else: max_tokens_across_dp = num_tokens if num_tokens is not None: # NOTE: token num which need to pad to when mc2 - padded_num_tokens = math.ceil( - max_tokens_across_dp / tp_world_size) * tp_world_size + padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size reserved_mc2_mask = get_mc2_mask() if reserved_mc2_mask is not None: mc2_mask = reserved_mc2_mask[:padded_num_tokens] diff --git a/vllm_ascend/profiling_config.py b/vllm_ascend/profiling_config.py index 8e0dfadf..f5c25dfc 100644 --- a/vllm_ascend/profiling_config.py +++ b/vllm_ascend/profiling_config.py @@ -21,9 +21,9 @@ This module generates the service_profiling_symbols.yaml configuration file to ~/.config/vllm_ascend/ directory. """ +import contextlib import tempfile from pathlib import Path -from typing import Optional import vllm from vllm.logger import logger @@ -120,7 +120,7 @@ SERVICE_PROFILING_SYMBOLS_YAML = """ def get_config_dir() -> Path: """ Get the vllm_ascend configuration directory path. - + Returns: Path: The path to ~/.config/vllm_ascend/ directory. """ @@ -129,32 +129,30 @@ def get_config_dir() -> Path: return config_dir -def _cleanup_temp_file(tmp_path: Optional[Path]) -> None: +def _cleanup_temp_file(tmp_path: Path | None) -> None: """ Clean up a temporary file if it exists. - + Args: tmp_path: Path to the temporary file to clean up. """ if tmp_path is not None and tmp_path.exists(): - try: + with contextlib.suppress(OSError): tmp_path.unlink() - except OSError: - pass # Ignore cleanup errors -def generate_service_profiling_config() -> Optional[Path]: +def generate_service_profiling_config() -> Path | None: """ Generate the service_profiling_symbols.yaml configuration file to ~/.config/vllm_ascend/ directory. - + If the configuration file already exists, this function will skip creating it and return the existing file path. - + If any error occurs during file creation, it will be logged but will not interrupt the execution. The function will return None to indicate that the file could not be created. - + Returns: Optional[Path]: The path to the generated (or existing) configuration file. Returns None if file creation failed. @@ -170,9 +168,7 @@ def generate_service_profiling_config() -> Optional[Path]: try: config_dir.mkdir(parents=True, exist_ok=True) except (OSError, PermissionError) as e: - logger.error( - f"Failed to create configuration directory {config_dir}: {e}", - exc_info=True) + logger.error(f"Failed to create configuration directory {config_dir}: {e}", exc_info=True) return None # Write the configuration file atomically using a temporary file @@ -180,13 +176,9 @@ def generate_service_profiling_config() -> Optional[Path]: tmp_path = None try: # Create a temporary file in the same directory for atomic write - with tempfile.NamedTemporaryFile(mode='w', - encoding='utf-8', - dir=config_dir, - delete=False, - suffix='.tmp', - prefix=CONFIG_FILENAME + - '.') as tmp_file: + with tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", dir=config_dir, delete=False, suffix=".tmp", prefix=CONFIG_FILENAME + "." + ) as tmp_file: tmp_file.write(SERVICE_PROFILING_SYMBOLS_YAML) tmp_path = Path(tmp_file.name) @@ -194,8 +186,7 @@ def generate_service_profiling_config() -> Optional[Path]: tmp_path.replace(config_file) return config_file except (OSError, PermissionError) as e: - logger.error(f"Failed to write configuration file {config_file}: {e}", - exc_info=True) + logger.error(f"Failed to write configuration file {config_file}: {e}", exc_info=True) return None finally: # Clean up the temporary file if it wasn't successfully replaced diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 0cdd6067..ca5bb41f 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -17,6 +17,8 @@ # Adapted from vllm-project/vllm/vllm/worker/worker.py # +from __future__ import annotations + import atexit import functools import math @@ -25,7 +27,7 @@ from contextlib import contextmanager, nullcontext from enum import Enum from functools import lru_cache from threading import Lock -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any import torch import torch_npu # noqa: F401 @@ -88,6 +90,7 @@ def acl_graph_print(*args): resolving unexpected hangs. Usage: ```python from vllm_ascend.utils import acl_graph_print + ... acl_graph_print("Debug info") ``` @@ -113,8 +116,7 @@ def acl_graph_print(*args): torch_npu.npu._subscribe_report(current_compute_stream) _SUBSCRIBED_COMPUTE_STREAMS.add(current_compute_stream) - torch_npu.npu._launch_host_func(current_compute_stream, - _print_callback_on_stream, args) + torch_npu.npu._launch_host_func(current_compute_stream, _print_callback_on_stream, args) def _unregister_print_streams_on_exit(): @@ -196,9 +198,7 @@ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor: # after: pad_dims: [0, 2, 0, 3] # return: (1, 2, 16, 16) - return _custom_transpose( - _custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1, - 2).contiguous() + return _custom_transpose(_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1, 2).contiguous() def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor: @@ -208,11 +208,9 @@ def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor: tokens_pad = (num_tokens + 15) // 16 * 16 max_seq_len_pad = (max_seq_len + 15) // 16 * 16 - mask_tensor_pad = \ - torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device) + mask_tensor_pad = torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device) mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor - mask = mask_tensor_pad.reshape( - (1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3) + mask = mask_tensor_pad.reshape((1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3) return mask @@ -230,10 +228,7 @@ def aligned_16(tensor: torch.Tensor): return tensor # Create a new tensor with shape (n_aligned, H, W) and fill it with zeros - new_tensor = torch.zeros(n_aligned, - *tensor.shape[1:], - dtype=tensor.dtype, - device=tensor.device) + new_tensor = torch.zeros(n_aligned, *tensor.shape[1:], dtype=tensor.dtype, device=tensor.device) # Copy the original tensor to the first N positions of the new tensor new_tensor[:n] = tensor @@ -254,15 +249,15 @@ def enable_custom_op(): # isort: off # register custom ops into torch_library here import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 + # register the meta implementation for custom kernel if necessary import vllm_ascend.meta_registration # type: ignore # noqa: F401 + # isort: on _CUSTOM_OP_ENABLED = True except ImportError: _CUSTOM_OP_ENABLED = False - logger.warning( - "Warning: Failed to register custom ops, all custom ops will be disabled" - ) + logger.warning("Warning: Failed to register custom ops, all custom ops will be disabled") return _CUSTOM_OP_ENABLED @@ -277,8 +272,7 @@ def find_hccl_library() -> str: # manually load the hccl library if so_file: - logger.info("Found hccl from environment variable HCCL_SO_PATH=%s", - so_file) + logger.info("Found hccl from environment variable HCCL_SO_PATH=%s", so_file) else: if torch.version.cann is not None: so_file = "libhccl.so" @@ -318,6 +312,7 @@ def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig): global _WEIGHT_PREFETCH_METHOD if _WEIGHT_PREFETCH_METHOD is None: from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod + _WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config) return _WEIGHT_PREFETCH_METHOD @@ -364,6 +359,7 @@ def vllm_version_is(target_vllm_version: str): vllm_version = envs_ascend.VLLM_VERSION else: import vllm + vllm_version = vllm.__version__ try: return Version(vllm_version) == Version(target_vllm_version) @@ -372,7 +368,8 @@ def vllm_version_is(target_vllm_version: str): f"Invalid vllm version {vllm_version} found. A dev version of vllm " "is installed probably. Set the environment variable VLLM_VERSION " "to control it by hand. And please make sure the value follows the " - "format of x.y.z.") + "format of x.y.z." + ) def get_max_hidden_layers(hf_config) -> int: @@ -394,20 +391,19 @@ def get_max_hidden_layers(hf_config) -> int: # Update cudagraph capture sizes for vllm config -def update_cudagraph_capture_sizes(vllm_config: VllmConfig, - cudagraph_capture_sizes: List[int]): - - valid_max_size = (cudagraph_capture_sizes[-1] - if cudagraph_capture_sizes else 0) - if (vllm_config.compilation_config.max_cudagraph_capture_size is not None - and vllm_config.compilation_config.max_cudagraph_capture_size - != valid_max_size): +def update_cudagraph_capture_sizes(vllm_config: VllmConfig, cudagraph_capture_sizes: list[int]): + valid_max_size = cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0 + if ( + vllm_config.compilation_config.max_cudagraph_capture_size is not None + and vllm_config.compilation_config.max_cudagraph_capture_size != valid_max_size + ): if vllm_config.compilation_config.cudagraph_capture_sizes is not None: raise ValueError( "customized max_cudagraph_capture_size" f"(={vllm_config.compilation_config.max_cudagraph_capture_size}) " "should be consistent with the max value of " - f"cudagraph_capture_sizes(={valid_max_size})") + f"cudagraph_capture_sizes(={valid_max_size})" + ) logger.warning( "Truncating max_cudagraph_capture_size to %d", valid_max_size, @@ -415,12 +411,11 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig, vllm_config.compilation_config.max_cudagraph_capture_size = valid_max_size - if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len( - cudagraph_capture_sizes) < len( - vllm_config.compilation_config.cudagraph_capture_sizes): + if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len(cudagraph_capture_sizes) < len( + vllm_config.compilation_config.cudagraph_capture_sizes + ): logger.warning( - ("cudagraph_capture_sizes specified in compilation_config" - " %s is overridden by config %s"), + ("cudagraph_capture_sizes specified in compilation_config %s is overridden by config %s"), vllm_config.compilation_config.cudagraph_capture_sizes, cudagraph_capture_sizes, ) @@ -433,26 +428,17 @@ def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool: Check whether it is vLLM default capture sizes. """ - max_cudagraph_capture_size = \ - vllm_config.compilation_config.max_cudagraph_capture_size - cudagraph_capture_sizes = [ - i for i in [1, 2, 4] if i <= max_cudagraph_capture_size - ] + max_cudagraph_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size + cudagraph_capture_sizes = [i for i in [1, 2, 4] if i <= max_cudagraph_capture_size] if max_cudagraph_capture_size >= 8: # Step size 8 for small batch sizes, up to 256(not included) - cudagraph_capture_sizes += list( - range(8, min(max_cudagraph_capture_size + 1, 256), 8)) + cudagraph_capture_sizes += list(range(8, min(max_cudagraph_capture_size + 1, 256), 8)) if max_cudagraph_capture_size >= 256: # Step size 16 for larger batch sizes - cudagraph_capture_sizes += list( - range(256, max_cudagraph_capture_size + 1, 16)) + cudagraph_capture_sizes += list(range(256, max_cudagraph_capture_size + 1, 16)) # in newer version, vLLM use ascending order of cudagraph_capture_sizes. target_cudagraph_capture_sizes = sorted(cudagraph_capture_sizes) - if target_cudagraph_capture_sizes == \ - vllm_config.compilation_config.cudagraph_capture_sizes: - return True - - return False + return target_cudagraph_capture_sizes == vllm_config.compilation_config.cudagraph_capture_sizes def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None: @@ -461,9 +447,11 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None: are more friendly to ascend ops && hardware. """ - if vllm_config.model_config is None or \ - vllm_config.model_config.enforce_eager or \ - not _is_default_capture_sizes(vllm_config): + if ( + vllm_config.model_config is None + or vllm_config.model_config.enforce_eager + or not _is_default_capture_sizes(vllm_config) + ): return # modify the default capture_sizes for Qwen3-MoE models on dp settings. @@ -471,16 +459,15 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None: # on special shapes. # TODO(Angazenn): we will remove this once _npu_paged_attention is fully # replaced by npu_fused_infer_attention_score which does not contain such bugs. - if vllm_config.model_config and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe" \ - and vllm_config.parallel_config.tensor_parallel_size == 1 \ - and vllm_config.parallel_config.data_parallel_size > 1 : - + if ( + vllm_config.model_config + and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe" + and vllm_config.parallel_config.tensor_parallel_size == 1 + and vllm_config.parallel_config.data_parallel_size > 1 + ): max_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size - new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [ - i for i in range(24, max_capture_size + 1, 8) - ] - update_cudagraph_capture_sizes(vllm_config, - new_cudagraph_capture_sizes) + new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [i for i in range(24, max_capture_size + 1, 8)] + update_cudagraph_capture_sizes(vllm_config, new_cudagraph_capture_sizes) def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: @@ -498,8 +485,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: # Store original configuration and temporarily clear it compilation_config = vllm_config.compilation_config - original_sizes, compilation_config.cudagraph_capture_sizes = \ - compilation_config.cudagraph_capture_sizes, None + original_sizes, compilation_config.cudagraph_capture_sizes = compilation_config.cudagraph_capture_sizes, None # Calculate parallel configuration factor if not vllm_config.model_config: @@ -510,7 +496,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: return hf_config = vllm_config.model_config.hf_text_config - if hasattr(hf_config, 'num_hidden_layers'): + if hasattr(hf_config, "num_hidden_layers"): num_hidden_layers = hf_config.num_hidden_layers else: num_hidden_layers = get_max_hidden_layers(hf_config) @@ -519,24 +505,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: # Calculate maximum supported batch sizes considering model architecture resources_per_graph = num_hidden_layers + 1 # For suffix decoding, use the suffix path when no draft_model_config is provided. - if (spec := vllm_config.speculative_config) and \ - (draft := spec.draft_model_config): + if (spec := vllm_config.speculative_config) and (draft := spec.draft_model_config): resources_per_graph += draft.hf_config.num_hidden_layers + 1 # TODO: Find out whether we need to take into account the pp_size - num_comm_groups = sum(size > 1 for size in [ - parallel_config.data_parallel_size, - parallel_config.tensor_parallel_size, - ]) + num_comm_groups = sum( + size > 1 + for size in [ + parallel_config.data_parallel_size, + parallel_config.tensor_parallel_size, + ] + ) - if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV': + if os.getenv("HCCL_OP_EXPANSION_MODE") == "AIV": # TODO: Find out whether we need to take into account the pp_size - parallel_factor = 1 + num_comm_groups + int( - parallel_config.enable_expert_parallel) + int( - vllm_config.additional_config.get( - "multistream_overlap_shared_expert", False)) + parallel_factor = ( + 1 + + num_comm_groups + + int(parallel_config.enable_expert_parallel) + + int(vllm_config.additional_config.get("multistream_overlap_shared_expert", False)) + ) if is_moe_model(vllm_config): - parallel_factor += (parallel_config.data_parallel_size > 1) + parallel_factor += parallel_config.data_parallel_size > 1 else: # When AIV mode is enabled, the allreduce operator of the dense # layer model will occupy additional streams, which are buffered here. @@ -546,11 +536,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: # Assume the following case: # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19 - max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / - resources_per_graph / parallel_factor) - logger.info( - "Calculated maximum supported batch sizes for ACL graph: %s", - max_num_batch_sizes) + max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / resources_per_graph / parallel_factor) + logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes) else: # enable pcp or dcp will add new communication and consume additional approximately less than 100 streams if parallel_config.prefill_context_parallel_size > 1: @@ -562,18 +549,18 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: # Under this configuration, HCCL employs the FFTS+ method for execution unfolding, # which adds only 1 concurrent stream without consuming collective communication execution unfolding streams. # On A3 hardware, HCCL defaults to the AICPU method. - # This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case). - # Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes. + # This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication + # domain on the device (worst case). + # Using the default collective communication unfolding method on A3 will lead to a significant reduction + # in the maximum supported sizes. # Therefore, the calculation formula has been modified as follows: # Assume the following case: # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12 max_num_batch_sizes = math.floor( - (MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / - (1 + num_comm_groups * 2)) - logger.info( - "Calculated maximum supported batch sizes for ACL graph: %s", - max_num_batch_sizes) + (MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / (1 + num_comm_groups * 2) + ) + logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes) logger.warning( "Currently, communication is performed using FFTS+ method, which reduces " "the number of available streams and, as a result, limits the range of runtime " @@ -598,26 +585,29 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: vllm_config.model_config.architectures[0], num_hidden_layers, len(original_sizes), - len(compilation_config. - cudagraph_capture_sizes # type: ignore[arg-type] - )) + len( + compilation_config.cudagraph_capture_sizes # type: ignore[arg-type] + ), + ) else: # No adjustment needed compilation_config.cudagraph_capture_sizes = original_sizes logger.info( "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes", - vllm_config.model_config.architectures[0], num_hidden_layers, - len(original_sizes)) + vllm_config.model_config.architectures[0], + num_hidden_layers, + len(original_sizes), + ) # TODO(wxy): Move to ops module def dispose_tensor(x: torch.Tensor): - x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype)) + x.set_(torch.empty((0,), device=x.device, dtype=x.dtype)) class ProfileExecuteDuration: _instance = None - _observations: List[Tuple[str, Event, Event]] = [] + _observations: list[tuple[str, Event, Event]] = [] _lock = Lock() def __new__(cls): @@ -645,8 +635,7 @@ class ProfileExecuteDuration: observe_end = Event(enable_timing=True) observe_end.record() with self._lock: - self._observations.append( - (duration_tag, observe_start, observe_end)) + self._observations.append((duration_tag, observe_start, observe_end)) def pop_captured_sync(self) -> dict: """Pop and synchronize all events in the observation list""" @@ -663,7 +652,7 @@ class ProfileExecuteDuration: return durations -def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): +def register_ascend_customop(vllm_config: VllmConfig | None = None): """Register Ascend CustomOP NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, @@ -675,23 +664,29 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): from vllm.model_executor.custom_op import CustomOp from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul - from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE, - AscendSharedFusedMoE) + from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm - from vllm_ascend.ops.linear import (AscendColumnParallelLinear, - AscendMergedColumnParallelLinear, - AscendQKVParallelLinear, - AscendReplicatedLinear, - AscendRowParallelLinear) + from vllm_ascend.ops.linear import ( + AscendColumnParallelLinear, + AscendMergedColumnParallelLinear, + AscendQKVParallelLinear, + AscendReplicatedLinear, + AscendRowParallelLinear, + ) from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention from vllm_ascend.ops.rotary_embedding import ( - AscendApplyRotaryEmb, AscendDeepseekScalingRotaryEmbedding, - AscendMRotaryEmbedding, AscendRotaryEmbedding, - AscendYaRNRotaryEmbedding) + AscendApplyRotaryEmb, + AscendDeepseekScalingRotaryEmbedding, + AscendMRotaryEmbedding, + AscendRotaryEmbedding, + AscendYaRNRotaryEmbedding, + ) from vllm_ascend.ops.vocab_parallel_embedding import ( - AscendLogitsProcessor, AscendParallelLMHead, - AscendVocabParallelEmbedding) + AscendLogitsProcessor, + AscendParallelLMHead, + AscendVocabParallelEmbedding, + ) global REGISTERED_ASCEND_OPS REGISTERED_ASCEND_OPS = { @@ -738,6 +733,7 @@ _ascend_device_type = None def _init_ascend_device_type(): global _ascend_device_type from vllm_ascend import _build_info # type: ignore + _ascend_device_type = AscendDeviceType[_build_info.__device_type__] @@ -758,7 +754,10 @@ def check_ascend_device_type(): else: raise RuntimeError(f"Can not support soc_version: {soc_version}.") - assert _ascend_device_type == cur_device_type, f"Current device type: {cur_device_type} does not match the installed version's device type: {_ascend_device_type}, please check your installation package." + assert _ascend_device_type == cur_device_type, ( + f"Current device type: {cur_device_type} does not match the installed version's device type: " + f"{_ascend_device_type}, please check your installation package." + ) def get_ascend_device_type(): @@ -769,23 +768,19 @@ def get_ascend_device_type(): def lmhead_tp_enable() -> bool: - return get_ascend_config( - ).finegrained_tp_config.lmhead_tensor_parallel_size > 0 + return get_ascend_config().finegrained_tp_config.lmhead_tensor_parallel_size > 0 def embedding_tp_enable() -> bool: - return get_ascend_config( - ).finegrained_tp_config.embedding_tensor_parallel_size > 0 + return get_ascend_config().finegrained_tp_config.embedding_tensor_parallel_size > 0 def oproj_tp_enable() -> bool: - return get_ascend_config( - ).finegrained_tp_config.oproj_tensor_parallel_size > 0 + return get_ascend_config().finegrained_tp_config.oproj_tensor_parallel_size > 0 def mlp_tp_enable() -> bool: - return get_ascend_config( - ).finegrained_tp_config.mlp_tensor_parallel_size > 0 + return get_ascend_config().finegrained_tp_config.mlp_tensor_parallel_size > 0 def matmul_allreduce_enable() -> bool: @@ -797,30 +792,30 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool: if _ENABLE_SP is None: if vllm_config is None: from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() _ENABLE_SP = ( vllm_config.compilation_config.pass_config.enable_sp or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1 # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1 # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility. - or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0')))) + or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) + ) if not _ENABLE_SP and enable_shared_expert_dp: _ENABLE_SP = True - logger.info( - "shared_expert_dp requires enable_sp = True. has set enable_sp to True" - ) + logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True") if not _ENABLE_SP: return _ENABLE_SP - assert vllm_config.parallel_config.tensor_parallel_size > 1, \ + assert vllm_config.parallel_config.tensor_parallel_size > 1, ( "Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1." + ) - assert ( - not is_moe_model(vllm_config) - or vllm_config.parallel_config.enable_expert_parallel - ), "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models." + assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, ( + "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models." + ) return _ENABLE_SP @@ -847,25 +842,21 @@ def is_drafter_moe_model(vllm_config: VllmConfig): """Checks if the drafter model is a MoE model by config""" global _IS_DRAFTER_MOE_MODEL if _IS_DRAFTER_MOE_MODEL is None: - model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config \ - .to_dict() + model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config.to_dict() _IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs) return _IS_DRAFTER_MOE_MODEL -def speculative_enable_dispatch_gmm_combine_decode( - vllm_config: VllmConfig) -> bool: +def speculative_enable_dispatch_gmm_combine_decode(vllm_config: VllmConfig) -> bool: if vllm_config.speculative_config is None: return True - speculative_method = getattr(vllm_config.speculative_config, "method", - None) + speculative_method = getattr(vllm_config.speculative_config, "method", None) if speculative_method in [None, "ngram", "suffix"]: return True if speculative_method in ["eagle", "eagle3"]: return False if speculative_method == "mtp": - mtp_quant_type = getattr(vllm_config.model_config.hf_text_config, - "mtp_quantize", None) + mtp_quant_type = getattr(vllm_config.model_config.hf_text_config, "mtp_quantize", None) return mtp_quant_type == "w8a8_dynamic" return False @@ -915,8 +906,8 @@ def weak_ref_tensor(tensor: Any) -> Any: def weak_ref_tensors( - tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] -) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: + tensors: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor], +) -> torch.Tensor | list[Any] | tuple[Any] | Any: """ Convenience function to create weak references to tensors, for single tensor, list of tensors or tuple of tensors. @@ -936,17 +927,12 @@ def weak_ref_tensors( return tuple(weak_ref_tensor(t) for t in tensors) # For IntermediateTensors used in pipeline parallelism if isinstance(tensors, IntermediateTensors): - ret = IntermediateTensors({ - key: weak_ref_tensor(val) - for key, val in tensors.tensors.items() - }) + ret = IntermediateTensors({key: weak_ref_tensor(val) for key, val in tensors.tensors.items()}) return ret raise ValueError("Invalid type for tensors") -def npu_stream_switch(target_stream: torch.npu.Stream, - *, - enabled: bool = True): +def npu_stream_switch(target_stream: torch.npu.Stream, *, enabled: bool = True): """ Switch to the target stream if enabled is True. Otherwise, do nothing. @@ -965,7 +951,7 @@ def create_hccl_pg_options(group_name: str): return options -def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]: +def get_hccl_config_for_pg_options(group_name: str) -> dict | None: """ Get HCCL process group options for the given communication group name. @@ -981,9 +967,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]: if group_name and "mc2" in group_name: return None hccl_config_map = { - "dp": { - "hccl_buffer_size": calculate_dp_buffer_size() - }, + "dp": {"hccl_buffer_size": calculate_dp_buffer_size()}, } return hccl_config_map.get(group_name, get_default_buffer_config()) @@ -998,6 +982,7 @@ def calculate_dp_buffer_size() -> int: dp_size + 1 (flags: with_prefill) """ from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() dp_size = vllm_config.parallel_config.data_parallel_size int32_size = torch.iinfo(torch.int32).bits // 8 @@ -1009,8 +994,7 @@ def calculate_dp_buffer_size() -> int: # and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and # significantly improve communication performance of MC2 ops dispatch/combine. def is_hierarchical_communication_enabled(): - return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" - and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1") + return os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1" def has_layer_idx(model_instance: torch.nn.Module) -> bool: @@ -1019,8 +1003,7 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool: global _HAS_LAYER_IDX if _HAS_LAYER_IDX is None: - _HAS_LAYER_IDX = hasattr(model_instance, "model") and \ - hasattr(model_instance.model, "start_layer") + _HAS_LAYER_IDX = hasattr(model_instance, "model") and hasattr(model_instance.model, "start_layer") return _HAS_LAYER_IDX @@ -1042,20 +1025,17 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config): if not flashcomm2_enable(): return 0 - logger.info( - f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}" - ) + logger.info(f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}") layer_sharding = ascend_config.layer_sharding or [] if layer_sharding: if layer_sharding == ["o_proj"]: - logger.info_once( - "Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption." - ) + logger.info_once("Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption.") else: raise ValueError( "FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! " - f"Found invalid layer_sharding: {layer_sharding}") + f"Found invalid layer_sharding: {layer_sharding}" + ) if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1: logger.warning_once( "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance." @@ -1066,32 +1046,35 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config): ) if global_tp_size <= flashcomm2_oproj_tp_size: raise AssertionError( - f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed global tensor parallel size ({global_tp_size})" + f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed " + f"global tensor parallel size ({global_tp_size})" ) if global_tp_size % flashcomm2_oproj_tp_size != 0: raise AssertionError( - f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})" + f"Global tensor parallel size ({global_tp_size}) must be divisible by " + f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})" ) if vllm_config.kv_transfer_config is None: logger.warning_once( - "It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment may lead to decode performance degradation." + "It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment " + "may lead to decode performance degradation." ) if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer: raise AssertionError( - "FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments." + "FLASHCOMM2 primarily targets P-scenario deployments, with additional support " + "for hybrid deployment scenarios. It is not applicable in D-scenario environments." ) return flashcomm2_oproj_tp_size def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]: - # Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain. + # Reorganize batch_ids so that, after the all2all and reduce-scatter operation, + # each batch_id corresponds to the rank_id within the DP domain. # For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2, # the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]]. - flashcomm2_otp_size = get_ascend_config( - ).flashcomm2_oproj_tensor_parallel_size - num_oproj_tensor_parallel_groups: int = (global_tp_size // - flashcomm2_otp_size) + flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size + num_oproj_tensor_parallel_groups: int = global_tp_size // flashcomm2_otp_size reorgnized_batch_ids = [] for i in range(num_oproj_tensor_parallel_groups): @@ -1122,11 +1105,9 @@ def refresh_block_size(vllm_config): return # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups. - if not model_config.hf_text_config.model_type == "qwen3_next" and cache_config.block_size != 128: + if model_config.hf_text_config.model_type != "qwen3_next" and cache_config.block_size != 128: if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill: - logger.info( - "Block size is set to 128 if prefix cache or chunked prefill is enabled." - ) + logger.info("Block size is set to 128 if prefix cache or chunked prefill is enabled.") cache_config.block_size = 128 @@ -1138,7 +1119,6 @@ def dispose_layer(layer: Any): def check_kv_extra_config(vllm_config): - def _check(name: str, config: dict): tp_key = "tp_size" dp_key = "dp_size" @@ -1148,24 +1128,21 @@ def check_kv_extra_config(vllm_config): if config_tp != vllm_tp: raise ValueError( f"KV transfer '{name}' config has a conflicting tensor parallel size. " - f"Expected {vllm_tp}, but got {config_tp}.") + f"Expected {vllm_tp}, but got {config_tp}." + ) if dp_key in config: config_dp = config[dp_key] vllm_dp = vllm_config.parallel_config.data_parallel_size if config_dp != vllm_dp: raise ValueError( f"KV transfer '{name}' config has a conflicting data parallel size. " - f"Expected {vllm_dp}, but got {config_dp}.") + f"Expected {vllm_dp}, but got {config_dp}." + ) if vllm_config.kv_transfer_config.is_kv_producer: - _check( - "prefill", - vllm_config.kv_transfer_config.get_from_extra_config( - "prefill", {})) + _check("prefill", vllm_config.kv_transfer_config.get_from_extra_config("prefill", {})) if vllm_config.kv_transfer_config.is_kv_consumer: - _check( - "decode", - vllm_config.kv_transfer_config.get_from_extra_config("decode", {})) + _check("decode", vllm_config.kv_transfer_config.get_from_extra_config("decode", {})) def singleton(cls): @@ -1179,17 +1156,17 @@ def singleton(cls): return get_instance -#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1 +# TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. +# and subsequent updates will introduce new interfaces. --zzhx1 @lru_cache(maxsize=1) def enable_dsa_cp() -> bool: from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() - is_ds_v32 = hasattr( - vllm_config.model_config, "hf_text_config") and hasattr( - vllm_config.model_config.hf_text_config, "index_topk") - if is_ds_v32 and enable_sp(): - return True - return False + is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr( + vllm_config.model_config.hf_text_config, "index_topk" + ) + return bool(is_ds_v32 and enable_sp()) @lru_cache(maxsize=1) @@ -1197,6 +1174,7 @@ def enable_dsa_cp_with_layer_shard() -> bool: if not enable_dsa_cp(): return False from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer return is_prefill_instance