xc-llm-ascend/vllm_ascend/compilation/acl_graph.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
[Lint]Style: Convert `vllm-ascend/compilation` to ruff format (#5912) ### What this PR does / why we need it? Convert `vllm-ascend/compilation` to ruff format. ### Does this PR introduce _any_ user-facing change? During this migration, we encountered some **errors** in our CI and testing environments, such as: ``` vllm_ascend/utils.py:653: in <module> def register_ascend_customop(vllm_config: VllmConfig | None = None): ^^^^^^^^^^^^^^^^^ E TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType' ``` **1. Root Cause Analysis:** The project uses a common pattern to break circular dependencies: ```python if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None # Placeholder assigned at runtime ``` When Python parses the function definition `def register_ascend_customop(vllm_config: VllmConfig | None)`, it attempts to evaluate the expression `VllmConfig | None`. Since `VllmConfig` is assigned `None` at runtime, the expression effectively becomes `None | None`. In Python, `None` is an instance of `NoneType`. While the `|` operator is implemented for Type objects (classes), it is not supported for `NoneType` instances, leading to the `TypeError` shown above. **2. Solution:** To maintain the modern `|` syntax required by our new linting standards while preserving our dependency management strategy, I have introduced: ```python from __future__ import annotations ``` at the top of the affected files. This enables **Postponed Evaluation of Annotations (PEP 563)**. **3. Impact and Benefits:** - By enabling `annotations`, Python no longer executes the `VllmConfig | None` operation during module load. Instead, it stores the annotation as a string literal, completely avoiding the `None | None` calculation. - We can keep the `VllmConfig = None` placeholders. This ensures that other modules can still import these symbols without triggering an `ImportError`, maintaining a stable dependency graph. 
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the types correctly. This allows us to use modern syntax without sacrificing type safety or runtime stability. - The only side effect is that `__annotations__` will now return strings instead of type objects. Since this module does not use runtime type enforcement or reflection, this change has zero negative impact on existing functionality. ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/11b6af5280d6d6dfb8953af16e67b25f819b3be9 --------- Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-16 20:57:46 +08:00
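The failure and the fix described in the commit message above can be reproduced on their own. This is a minimal sketch that assumes `None` is the runtime placeholder for `VllmConfig`. The helpers `eager_union_raises`, `POSTPONED_SRC` and `define_with_postponed_annotations` are hypothetical, and `exec` is used only so the future import can be applied to a small snippet. The eager case evaluates the union expression directly rather than through a `def`. Newer interpreters (Python 3.14, PEP 649) defer annotation evaluation by default, so a plain `def` would no longer fail on them.

```python
# Hypothetical reproduction of the TypeError and the PEP 563 fix described above.
# Assumption: at runtime VllmConfig is the placeholder None, mirroring the
# `if TYPE_CHECKING: ... else: VllmConfig = None` pattern.


def eager_union_raises() -> bool:
    """Return True if evaluating `VllmConfig | None` eagerly raises TypeError."""
    VllmConfig = None  # runtime placeholder
    try:
        _ = VllmConfig | None  # what an eagerly evaluated annotation computes
    except TypeError:
        return True
    return False


# Module-like source that opts into postponed evaluation of annotations (PEP 563).
POSTPONED_SRC = (
    "from __future__ import annotations\n"
    "def register_ascend_customop(vllm_config: VllmConfig | None = None):\n"
    "    return vllm_config\n"
)


def define_with_postponed_annotations():
    """Compile the snippet with VllmConfig bound to None and return the function."""
    namespace = {"VllmConfig": None}
    exec(POSTPONED_SRC, namespace)  # no TypeError: the annotation is never evaluated
    return namespace["register_ascend_customop"]


fn = define_with_postponed_annotations()
print(eager_union_raises())  # True
print(fn.__annotations__)  # {'vllm_config': 'VllmConfig | None'}
```

The annotation is kept as the string `'VllmConfig | None'`. This is the side effect the commit message mentions: `__annotations__` holds strings instead of type objects.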
from collections.abc import Callable
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any
from unittest.mock import patch

import torch
import torch_npu
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.platforms import current_platform

from vllm_ascend.ascend_forward_context import _EXTRA_CTX

from ..utils import weak_ref_tensors


@dataclasses.dataclass
class ACLGraphEntry:
    batch_descriptor: BatchDescriptor
    aclgraph: torch.npu.NPUGraph | None = None
    output: Any | None = None
    # for aclgraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: list[int] | None = None


class ACLGraphWrapper:
    """Wraps a runnable to add aclgraph capturing and replaying ability, and
    provides attribute access to the underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in aclgraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
    PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
    batch_descriptor (key) from the forward context and trusts them blindly
    for aclgraph dispatching.
    3. If runtime_mode is NONE or does not match the mode of the wrapper,
    the runnable is called directly.
    4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
    the wrapper performs an aclgraph capture (if the key does not exist, it
    creates a new entry and caches it) or a replay (if the key exists in the
    cache).

    Note: ACLGraphWrapper does not store persistent buffers or copy any
    runtime inputs into such buffers for replay. We assume this is handled
    outside of the wrapper, because we make no assumption about the dynamic
    shape (batch size) of the runtime inputs, as a trade-off for staying
    orthogonal to compilation logic. Nevertheless, when
    VLLM_LOGGING_LEVEL == "DEBUG", the input addresses are traced and checked
    for consistency during replay.
    """

    def __init__(
        self,
        runnable: Callable,
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        cudagraph_options: CUDAGraphOptions | None = None,
    ):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config
        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
[Bugfix] Fix slow hasattr in ACLGraphWrapper.__getattr__ (#7442) ### What this PR does / why we need it? Follow https://github.com/vllm-project/vllm/pull/37425, https://github.com/vllm-project/vllm-omni/pull/1982 Copied from them: Notice that `hasattr(self.model, "flush_pending_metadata")` cost 6ms per decode step when profiling Qwen3 Omni. The original `CUDAGraphWrapper.__getattr__` raises: ```python raise AttributeError(f"... cudagraph wrapper: {self.runnable}") ``` When hasattr() is called for a non-existent attribute, Python internally calls __getattr__ which constructs this AttributeError. The {self.runnable} triggers `__repr__()` on the underlying model (e.g., `Qwen3OmniMoeForConditionalGeneration`), which recursivelytraverses the entire nn.Module tree to generate an 18,000+ character string. This takes ~6-7ms per call. Since `hasattr(self.model, "flush_pending_metadata") ` is called every decode step in the Talker forward path, this adds ~6ms overhead per step, severely impacting audio inter-chunk latency (ICL). ```Python hasattr(self.model, "flush_pending_metadata") → getattr(self.model, "flush_pending_metadata") → not found in CUDAGraphWrapper.__dict__ → not found in the CUDAGraphWrapper class hierarchy → triggers CUDAGraphWrapper.__getattr__("flush_pending_metadata") → hasattr(self.runnable, "flush_pending_metadata") # runnable also doesn't have it → executes raise AttributeError(f"... {self.runnable}") → Python needs to construct the exception object → the f-string triggers self.runnable.__repr__() → Qwen3OmniMoeForConditionalGeneration.__repr__() → recursively traverses the entire nn.Module tree → generates a 18,000+ character string → takes ~6 ms → AttributeError object is created → hasattr catches the AttributeError and returns False → the 18,000-character string is immediately discarded (no one ever sees it) ``` ### Does this PR introduce _any_ user-facing change? NO. ### How was this patch tested? 
See https://github.com/vllm-project/vllm-omni/pull/1982 - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4497431df654e46fb1fb5e64bf8611e762ae5d87 --------- Signed-off-by: gcanlin <canlinguosdu@gmail.com>
2026-03-23 09:26:24 +08:00
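The call chain described above can be reproduced without vLLM. This sketch uses hypothetical stand-ins, not the vLLM or vllm-ascend classes. `ExpensiveRepr` plays the role of a large `nn.Module` and counts its `__repr__` calls. `EagerMessageWrapper` formats the runnable into every `AttributeError`. `CachedMessageWrapper` follows the pattern used by `ACLGraphWrapper` here, where `_runnable_str` is computed once and only in debug mode. The sketch counts calls rather than measuring wall-clock time. It shows where the cost comes from but does not reproduce the ~6 ms figure.

```python
# Simplified, hypothetical stand-ins illustrating the hasattr() slowdown described above.
# These classes are NOT the vLLM or vllm-ascend implementations.


class ExpensiveRepr:
    """Stand-in for a large nn.Module whose __repr__ is costly; counts calls."""

    def __init__(self) -> None:
        self.repr_calls = 0
        self.known_attr = 42

    def __repr__(self) -> str:
        self.repr_calls += 1
        return "ExpensiveRepr(" + "x" * 18_000 + ")"


class EagerMessageWrapper:
    """Formats the runnable into every AttributeError, even when hasattr() discards it."""

    def __init__(self, runnable) -> None:
        self.runnable = runnable

    def __getattr__(self, key: str):
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        # The f-string calls str(self.runnable), i.e. the expensive __repr__.
        raise AttributeError(f"Attribute {key} not found in wrapper: {self.runnable}")


class CachedMessageWrapper:
    """Mirrors the fix: stringify the runnable once, and only in debug mode."""

    def __init__(self, runnable, debugging: bool = False) -> None:
        self.runnable = runnable
        self.is_debugging_mode = debugging
        self._runnable_str = str(runnable) if debugging else None

    def __getattr__(self, key: str):
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        if self.is_debugging_mode:
            raise AttributeError(f"Attribute {key} not found in runnable: {self._runnable_str}")
        raise AttributeError(f"Attribute {key} not found. Set VLLM_LOGGING_LEVEL=DEBUG for more details.")


def count_repr_calls(wrapper_factory, misses: int) -> int:
    """Probe a missing attribute `misses` times and report how often __repr__ ran."""
    model = ExpensiveRepr()
    wrapper = wrapper_factory(model)
    for _ in range(misses):
        hasattr(wrapper, "flush_pending_metadata")
    return model.repr_calls


print(count_repr_calls(EagerMessageWrapper, 3))  # 3
print(count_repr_calls(CachedMessageWrapper, 3))  # 0
```

The trade-off is that non-debug error messages no longer name the runnable. Debug mode restores that detail at the cost of a single up-front `str(runnable)`.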
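        # Stringify the (potentially huge) runnable once, and only in debug
        # mode, so that attribute misses in __getattr__ stay cheap.
        self._runnable_str = str(runnable) if self.is_debugging_mode else None
        # assert runtime_mode is not NONE (no aclgraph); otherwise, we don't
        # need to initialize an ACLGraphWrapper.
        assert self.runtime_mode != CUDAGraphMode.NONE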
upgrade to vllm 0.11.2 (#4400) Bump vLLM version to v0.11.2 What's broken and changed by vLLM: 1. structured_output is broken by https://github.com/vllm-project/vllm/pull/26866 2. get_mrope_input_positions is broken by https://github.com/vllm-project/vllm/pull/28399 3. graph mode is broken by https://github.com/vllm-project/vllm/pull/25110 we'll upgrade torch to 2.8 to fix the problem later 4. embedding is broken by https://github.com/vllm-project/vllm/pull/27583 5. `get_attn_backend_cls` and attention backend is broken are broken by https://github.com/vllm-project/vllm/pull/28534 6. spec decode is broken by https://github.com/vllm-project/vllm/pull/28771 7. sp feature is broken by https://github.com/vllm-project/vllm/pull/27126 8. mtp is broken by https://github.com/vllm-project/vllm/pull/27922 9. lora is broken by https://github.com/vllm-project/vllm/pull/21068 10. execute_model is broken by https://github.com/vllm-project/vllm/pull/26866 11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by https://github.com/vllm-project/vllm/pull/28159 12. kv cahe is broken by https://github.com/vllm-project/vllm/pull/27753 13. dp is broken by https://github.com/vllm-project/vllm/pull/25110 What's broken and changed by ourself: 1. qwen vl is broken by https://github.com/vllm-project/vllm/pull/28455 We'll remove model files in the future to avoid this kind of error 2. Engine core is broken by https://github.com/vllm-project/vllm/pull/23691 We'll remove the patch file in the future. 3. Ascend scheduler is broken by https://github.com/vllm-project/vllm/pull/28733 We'll remove ascend scheudler later. 4. qwen3-next is broken by https://github.com/vllm-project/vllm/pull/28083 We'll remove model files in the future to avoid this kind of error 5. qwen vl is broken by https://github.com/vllm-project/vllm/pull/27764. We'll remove model files in the future Known issue: 1. ray doesn't work 2. the accuracy of qwen3-next is not correct 3. qwen3-vl is broken 4. 
prefix cache+ ascend scheduler + deepseek v2 lite is broken. Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: hfadzxy <starmoon_zhang@163.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: 22dimensions <waitingwind@foxmail.com> Co-authored-by: shen-shanshan <467638484@qq.com> - vLLM version: v0.11.2 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: hfadzxy <starmoon_zhang@163.com> Signed-off-by: leo-pony <nengjunma@outlook.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: hfadzxy <starmoon_zhang@163.com> Co-authored-by: leo-pony <nengjunma@outlook.com>
2025-11-26 11:48:58 +08:00
        self.graph_pool = current_platform.get_global_graph_pool()
        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
        self.aclgraph_options = cudagraph_options
        # the entries for different batch descriptors that we need to capture
        # aclgraphs for.
        self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}

    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        if self.is_debugging_mode:
            raise AttributeError(
                f"Attribute {key} does not exist in the runnable of aclgraph wrapper: {self._runnable_str}"
            )
        raise AttributeError(f"Attribute {key} not found. Set VLLM_LOGGING_LEVEL=DEBUG for more details.")

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def __call__(self, *args, **kwargs):
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode

        if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode:
            # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
            # running without aclgraphs.
            # We do not trigger capture/replay if the runtime mode does not
            # match. This enables properly dispatching to the correct
            # ACLGraphWrapper when nesting multiple instances with different
            # runtime modes.
            return self.runnable(*args, **kwargs)

        if batch_descriptor not in self.concrete_aclgraph_entries:
            # create a new entry for this batch descriptor
            self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor)

        entry = self.concrete_aclgraph_entries[batch_descriptor]

        if entry.aclgraph is None:
            if self.aclgraph_options.debug_log_enable:
                # Since we capture aclgraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug("Capturing an aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor)
            # validate that aclgraph capturing is legal at this point.
            validate_cudagraph_capturing_enabled()

            input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
            entry.input_addresses = input_addresses
            aclgraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if self.aclgraph_options.gc_disable:
                    # during every model forward for piecewise aclgraph
                    # mode, we will capture many pieces of aclgraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the aclgraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.npu.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
[Feat][Graph] Support `FULL_DECODE_ONLY` mode for GQA/MHA models (#2128) Note: This depends on [vLLM #25161](https://github.com/vllm-project/vllm/pull/25161) and the torch\_npu release from September 30. ### What this PR does / why we need it? This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA models like DeepSeek V3/R1 are not included). Key improvements include: * **Reduced dispatch latency:** By replaying the entire model execution graph at once, we cut overhead compared with multiple smaller replays. * **Stabilized multi-device performance:** Captureing the whole model as one static graph also mitigates the dispatch fluctuations across devices. * **Stream/resource savings:** Consolidating graph captures frees up streams, allowing more graphs to be captured. **Known issues:** 1. `_npu_paged_attention` currently manages its own workspace in `torch_npu`, which can deadlock when synchronizing during graph replay — we’re working on a fix. There may be other corner cases. This PR is the first in a planned series; we’ll continue to iterate and address remaining issues in follow-ups. This is essentially a port of #1503 and #1677, but includes two major changes: 1. Let `graph_dispatcher` decide the graph mode instead of hard-coding it in the backend, which decouples Full Graph and Piecewise Graph and could make it possible to remove dynamo. 2. Adapt to the new `attn_group` logic, but leave a small hack in `update_graph_params`; multi-attention models may or may not be fully supported yet. ### Does this PR introduce _any_ user-facing change? ```python compilation_config={ "cudagraph_mode": "FULL_DECODE_ONLY", }, ``` ### How was this patch tested? Tests included. - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/9607d5eb449711b349d4c2bee0a9c94afcc7ed14 --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-09-22 17:14:28 +08:00
                forward_context.capturing = True
                with torch.npu.graph(aclgraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's aclgraph pool
                    output = self.runnable(*args, **kwargs)
                    if self.aclgraph_options.weak_ref_output:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph in piecewise aclgraph mode, because
                        # the output of the last graph will not be used by
                        # any other acl graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the workspaces
            # to save memory
            global _graph_params
            global _draft_graph_params
            weak_ref_workspaces(_graph_params)
            weak_ref_workspaces(_draft_graph_params)
            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.aclgraph = aclgraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during acl graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for aclgraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}"
            )

        logger.info_once("Replaying aclgraph")
        # In async scheduling or multi-threaded (MT) scenarios, the CPU-side
        # record event (from update_attn_params) for iteration i may complete
        # before the graph replay of iteration i-1 has finished.
        # To keep the two in order, we synchronize here before replaying, so
        # that update_attn_params only runs after the previous graph replay has
        # fully completed. The one exception is the EAGLE/EAGLE3 draft model in
        # FULL graph mode (merged eagle graph), which does not need this sync.
        use_eagle = (
            self.vllm_config.speculative_config.method in ("eagle", "eagle3")
            if self.vllm_config.speculative_config
            else False
        )
        if self.runtime_mode != CUDAGraphMode.FULL or not _EXTRA_CTX.is_draft_model or not use_eagle:
            torch.npu.current_stream().synchronize()
        entry.aclgraph.replay()
        return entry.output
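The capture-once, replay-many flow of `__call__` can be illustrated without an NPU. The sketch below is a hypothetical, CPU-only model: `ToyGraphWrapper`, `ToyEntry`, `Mode`, `needs_sync`, and `double_into` are illustrative names, not vLLM or torch_npu APIs. `functools.partial` stands in for a captured graph and `id()` stands in for `data_ptr()`. It assumes the captured callable writes into preallocated buffers, which is what makes replaying on the same inputs meaningful. It also encodes the synchronize-before-replay rule as a plain boolean function.

```python
import gc
from contextlib import ExitStack
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import Any, Callable, Optional
from unittest.mock import patch


class Mode(Enum):
    # Hypothetical stand-in for vLLM's CUDAGraphMode (only two members).
    PIECEWISE = "piecewise"
    FULL = "full"


def needs_sync(mode: Mode, is_draft_model: bool, use_eagle: bool) -> bool:
    # Same boolean shape as the check before replay: only the EAGLE/EAGLE3
    # draft model running in FULL mode skips the host-side synchronize.
    return mode != Mode.FULL or not is_draft_model or not use_eagle


@dataclass
class ToyEntry:
    graph: Optional[Callable[[], Any]] = None  # stand-in for a captured graph
    output: Any = None
    input_ids: list = field(default_factory=list)


class ToyGraphWrapper:
    """Hypothetical CPU-only illustration of capture-once / replay-many."""

    def __init__(self, fn: Callable[..., Any], gc_disable: bool = True):
        self.fn = fn
        self.gc_disable = gc_disable
        self.entries: dict = {}
        self.num_captured = 0

    def __call__(self, key: Any, *args: Any) -> Any:
        entry = self.entries.setdefault(key, ToyEntry())
        if entry.graph is None:
            # "Capture": remember which buffers were passed in.
            entry.input_ids = [id(a) for a in args]
            with ExitStack() as stack:
                if self.gc_disable:
                    # Make gc.collect a no-op while "capturing".
                    stack.enter_context(patch("gc.collect", lambda: None))
                entry.graph = partial(self.fn, *args)
                entry.output = entry.graph()
            self.num_captured += 1
            return entry.output
        # "Replay": the same buffers must be passed again, because the
        # captured work reads and writes those exact objects.
        assert [id(a) for a in args] == entry.input_ids, "input buffers changed"
        entry.graph()
        return entry.output


def double_into(x: list, out: list) -> list:
    # Writes into a preallocated output buffer, like kernels in a graph.
    out[:] = [2 * v for v in x]
    return out


# Usage: capture once for a given key, then replay with updated inputs.
x, out = [1, 2, 3], [0, 0, 0]
wrapper = ToyGraphWrapper(double_into)
wrapper("num_tokens=3", x, out)  # capture: out becomes [2, 4, 6]
x[:] = [5, 6, 7]  # refresh the static input buffer in place
wrapper("num_tokens=3", x, out)  # replay: out becomes [10, 12, 14]
```

Because replay reuses the exact buffers seen at capture time, callers update inputs in place rather than passing new objects; the address assertion in debugging mode exists to catch violations of that contract.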


def weak_ref_workspaces(params):
    if params is None:
        return
    for num_tokens in params.workspaces:
        if params.workspaces[num_tokens] is None:
            continue
        params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens])


def update_full_graph_params(
    attn_backend,
    update_stream,
    forward_context,
    num_tokens,
    vllm_config,
    speculative_config=None,
    num_dcp_pcp_tokens=None,
    draft_attn_metadatas=None,
):
    impl_cls = attn_backend.get_impl_cls()
    impl_cls.update_graph_params(
        update_stream,
        forward_context,
        num_tokens,
        vllm_config,
        speculative_config,
        num_dcp_pcp_tokens,
        draft_attn_metadatas,
    )
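`update_full_graph_params` does no work itself: it asks the attention backend for its implementation class and forwards every argument to that class's `update_graph_params`. A minimal sketch of that indirection, assuming the backend exposes a static `get_impl_cls()` as vLLM attention backends do; `ToyAttnBackend`, `ToyAttnImpl`, and `toy_update_full_graph_params` are hypothetical, and the impl only records the call instead of patching captured graph parameters.

```python
from typing import Any


class ToyAttnImpl:
    """Hypothetical attention impl; records calls instead of updating graphs."""

    calls: list = []

    @classmethod
    def update_graph_params(
        cls,
        update_stream,
        forward_context,
        num_tokens,
        vllm_config,
        speculative_config=None,
        num_dcp_pcp_tokens=None,
        draft_attn_metadatas=None,
    ):
        cls.calls.append(
            (update_stream, forward_context, num_tokens, vllm_config,
             speculative_config, num_dcp_pcp_tokens, draft_attn_metadatas)
        )


class ToyAttnBackend:
    """Hypothetical backend exposing get_impl_cls()."""

    @staticmethod
    def get_impl_cls() -> type:
        return ToyAttnImpl


def toy_update_full_graph_params(
    attn_backend: Any,
    update_stream,
    forward_context,
    num_tokens,
    vllm_config,
    speculative_config=None,
    num_dcp_pcp_tokens=None,
    draft_attn_metadatas=None,
) -> None:
    # Resolve the backend's impl class, then forward every argument unchanged.
    impl_cls = attn_backend.get_impl_cls()
    impl_cls.update_graph_params(
        update_stream, forward_context, num_tokens, vllm_config,
        speculative_config, num_dcp_pcp_tokens, draft_attn_metadatas,
    )


# Usage: the graph code never needs to know which backend it is talking to.
toy_update_full_graph_params(ToyAttnBackend, "stream", "ctx", 8, "cfg")
```

Keeping the per-backend update logic behind `get_impl_cls()` lets the graph-capture code stay backend-agnostic; each attention implementation decides which captured parameters it has to refresh.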


@dataclass
class GraphParams:
    events: dict[int, list[torch.npu.ExternalEvent]]
    workspaces: dict[int, torch.Tensor]
    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
    attn_params: dict[int, list[tuple]]


_graph_params: GraphParams | None = None
[FEAT] Support DeepSeek-V3.2 with `FULL_DECODE_ONLY` mode (#4706) ### What this PR does / why we need it? The first commit support `FULL_DECODE_ONLY`: - Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for slicing slots and positions, ensuring fixed tensor shapes. - Implement padding logic for `query_start_loc` in `NPUModelRunner` to support uniform decode in full graph mode, aligning with GPU runner behavior. - Adjust MLA cosine cache allocation to occur independently of graph mode and switch to using device-resident sequence lengths for attention metadata. - Remove redundant slicing of hidden states and outputs in `AscendSFAImpl` and optimize `sin`/`cos` cache updates. The second commit take MTP into account: - Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for slicing slots and positions, ensuring fixed tensor shapes. - Implement padding logic for `query_start_loc` in `NPUModelRunner` to support uniform decode in full graph mode, aligning with GPU runner behavior. - Adjust MLA cosine cache allocation to occur independently of graph mode and switch to using device-resident sequence lengths for attention metadata. - Remove redundant slicing of hidden states and outputs in `AscendSFAImpl` and optimize `sin`/`cos` cache updates. And the rest of them are just bugfix. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Test cases needed. - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-12-10 20:11:09 +08:00


def set_graph_params(aclgraph_capture_sizes: list[int]):
    global _graph_params
    if _graph_params is not None:
        raise ValueError("Graph parameters have already been set!")
    _graph_params = GraphParams(
        events={size: [] for size in aclgraph_capture_sizes},
        workspaces={size: None for size in aclgraph_capture_sizes},
        handles={size: [] for size in aclgraph_capture_sizes},
        attn_params={size: [] for size in aclgraph_capture_sizes},
    )


def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
    global _graph_params
    if _graph_params is not None:
        _graph_params.workspaces[num_tokens] = workspace


def get_graph_params():
    return _graph_params
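`set_graph_params`, `update_graph_params_workspaces`, and `get_graph_params` implement a set-once, module-level registry keyed by capture size. A minimal sketch of the same pattern with plain Python types, so it runs without `torch_npu`; `ToyGraphParams` and the `toy_*` functions are hypothetical, and strings stand in for workspace tensors.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyGraphParams:
    events: dict
    workspaces: dict
    handles: dict
    attn_params: dict


_toy_params: Optional[ToyGraphParams] = None


def toy_set_graph_params(capture_sizes: list) -> None:
    global _toy_params
    if _toy_params is not None:
        raise ValueError("Graph parameters have already been set!")
    # One fresh container per capture size; workspaces start empty (None).
    _toy_params = ToyGraphParams(
        events={size: [] for size in capture_sizes},
        workspaces={size: None for size in capture_sizes},
        handles={size: [] for size in capture_sizes},
        attn_params={size: [] for size in capture_sizes},
    )


def toy_update_graph_params_workspaces(num_tokens: int, workspace: Any) -> None:
    # Silently ignored until the registry exists, like the original.
    if _toy_params is not None:
        _toy_params.workspaces[num_tokens] = workspace


def toy_get_graph_params() -> Optional[ToyGraphParams]:
    return _toy_params
```

Each comprehension creates a fresh list per size, so appending events for one capture size never leaks into another; passing a single `[]` to `dict.fromkeys` would make every size share the same list.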
[Feat] Support MTP to running in full graph mode (#3892) ### What this PR does / why we need it? Currently, the MTP model still runs in eager in full graph mode. This PR adapts the MTP with the full graph capture and execution. When the graph mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to improve the performance. The change in both disable_padded_drafter_batch is True and False case include: 1. Add _mtp_graph_params in acl_graph.py to isolate the data of main model and the data of MTP. 2. Padding some metadata in mla_v1.py when in fullgraph mode. 3. Fixed the essential data address that will be used in model.forward. 4. Adapted according to the aclgraph capture framwork: 1). Rebuild MTP model with ACLGraphWrapper. 2). Add common attn metadata when start capture in MTP dummy_run. 3). Add common attn metadata update in MTP. 4). Addapted data update when num_speculative_tokens > 1. 5. Add a patch of MTP to adapt vllm v0.11.0. Existing Issues: 1. When disable_padded_drafter_batch=True and running in FullGraph mode, the data of the first-round requests in MTP is abnormal. We need to identify the cause subsequently. 2. When disable_padded_drafter_batch=False and running in FullGraph mode, the acceptance rate of the second and third tokens will decrease (For example, if we set the num_speculative_tokens=3, the acceptance rate of first token is 90%, the second is only 50% lower than 60%, the third is only 20% lower than 30%). The reason is that the data processed after the model runs does not match. This is a problem from another PR. It works fine in eager and PIECEWISE mode, but has problem in FullGraph mode. Once we have a solution, we will submit a bugfix. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379 --------- Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
_draft_graph_params: GraphParams | None = None
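

def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
    """Initialize the graph parameters of the draft (e.g. MTP) model.

    These are kept separate from the main model's graph parameters so the
    data of the two captured graphs stays isolated. May only be called once.
    """
    global _draft_graph_params
    if _draft_graph_params is not None:
        raise ValueError("DraftGraph parameters have already been set!")
    _draft_graph_params = GraphParams(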
        {size: [] for size in aclgraph_capture_sizes},
        {size: None for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
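    )


def update_draft_graph_params_workspaces(num_tokens: int, workspace: Any):
    """Record the workspace for the graph captured with ``num_tokens``.

    This is a no-op if the draft graph parameters have not been set.
    """
    global _draft_graph_params
    if _draft_graph_params is not None:
        _draft_graph_params.workspaces[num_tokens] = workspace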


def get_draft_graph_params() -> GraphParams | None:
    """Return the draft graph parameters, or ``None`` if not yet set."""
    return _draft_graph_params