xc-llm-ascend/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py


#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
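"""Inductor pattern-matcher pass fusing Q/K RMSNorm with rotary embedding.

The patterns below match the unfused attention prologue (fused-QKV split,
per-head RMSNorm on Q and K, then rotary embedding indexed from a cos/sin
cache) and rewrite it to the fused ``torch.ops.vllm.qkv_rmsnorm_rope``
custom op on Ascend NPU.
"""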
import torch
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import Range
from vllm.logger import logger
from vllm.model_executor.layers.attention import Attention
from vllm_ascend.compilation.passes.base_pattern import BasePattern
from vllm_ascend.utils import get_rope_dim
class QKNormRopeFusionPattern(BasePattern):
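    """QK-Norm + RoPE fusion pattern for attention layers without Q/K biases.

    Matches ``qkv.split`` -> per-head ``npu_rms_norm`` on Q and K ->
    ``npu_rotary_embedding``, and rewrites the subgraph to a single
    ``torch.ops.vllm.qkv_rmsnorm_rope`` call.
    """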
    def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
        super().__init__(vllm_config, eps)
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.device = vllm_config.device_config.device if vllm_config.device_config else None
        self.rope_dim = get_rope_dim(vllm_config)

    def get_inputs(self):
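        """Build representative trace-time tensors for pattern registration.

        Shapes only need to be plausible: the matcher traces these inputs
        once, so the token count (T=5) and cos/sin cache length (16384)
        are placeholder sizes.
        """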
        T = 5
        max_position_embeddings = 16384
        qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
        q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
        k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
        cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu")
        positions = torch.ones(T, dtype=torch.int64, device="npu")
        return [qkv, q_weight, k_weight, cos_sin_cache, positions]

    def get_pattern(self):
        def pattern(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
            cos_sin_cache: torch.Tensor,
            positions: torch.Tensor,
        ):
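            # Unfused reference graph: split the fused QKV projection, apply
            # per-head RMSNorm to Q and K, then rotary embedding from the cache.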
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
            q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
            k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
            q_flat = q_norm_out.view(q.shape)
            k_flat = k_norm_out.view(k.shape)
            q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.rope_dim, True
            )
            return q_rope, k_rope, v

        return pattern

    def get_replacement(self):
        def replacement(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
            cos_sin_cache: torch.Tensor,
            positions: torch.Tensor,
        ):
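            # Fused rewrite: one custom op performs the split, both RMSNorms,
            # and the rotary embedding.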
            results = torch.ops.vllm.qkv_rmsnorm_rope(
                input=qkv,
                q_weight=q_weight,
                k_weight=k_weight,
                q_hidden_size=self.q_size,
                kv_hidden_size=self.kv_size,
                head_dim=self.head_dim,
                eps=self.eps,
                q_bias=None,
                k_bias=None,
                cos_sin_cache=cos_sin_cache,
                positions=positions,
            )
            return results

        return replacement

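
# A minimal registration sketch (an illustrative assumption, not code from this
# file): patterns built above are typically wired into an inductor pass via
# torch._inductor's `register_replacement`, fed the trace-time inputs:
#
#   from torch._inductor.pattern_matcher import fwd_only, register_replacement
#
#   p = QKNormRopeFusionPattern(vllm_config, head_dim, num_heads, num_kv_heads)
#   register_replacement(p.get_pattern(), p.get_replacement(), p.get_inputs(),
#                        fwd_only, patterns)  # `patterns`: a PatternMatcherPass
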
class QKNormRopeFusionPatternWithBias(BasePattern):
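    """Variant of ``QKNormRopeFusionPattern`` with per-head Q/K bias tensors."""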
    def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
        super().__init__(vllm_config, eps)
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.device = vllm_config.device_config.device if vllm_config.device_config else None
        self.rope_dim = get_rope_dim(vllm_config)

    def get_inputs(self):
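        """Trace-time tensors; extends the base set with q/k bias tensors."""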
        T = 5
        max_position_embeddings = 16384
        qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
        q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
        k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
        q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
        k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu")
positions = torch.ones(T, dtype=torch.int64, device="npu")
return [qkv, q_weight, k_weight, q_bias, k_bias, cos_sin_cache, positions]
def get_pattern(self):
def pattern(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos_sin_cache: torch.Tensor,
positions: torch.Tensor,
):
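            # Unfused subgraph to match: split the fused QKV projection, apply
            # per-head RMSNorm with an additive bias to Q and K, then RoPE.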
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
q_normed = q_norm_out + q_bias
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
k_normed = k_norm_out + k_bias
q_flat = q_normed.view(q.shape)
k_flat = k_normed.view(k.shape)
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.rope_dim, True
)
return q_rope, k_rope, v
return pattern
def get_replacement(self):
def replacement(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos_sin_cache: torch.Tensor,
positions: torch.Tensor,
):
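            # Single fused NPU op covering the split, biased Q/K RMSNorm, and
            # rotary embedding that the pattern above performs step by step.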
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
k_weight=k_weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
eps=self.eps,
q_bias=q_bias,
k_bias=k_bias,
cos_sin_cache=cos_sin_cache,
positions=positions,
)
return results
return replacement
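    # A minimal sketch of how such a pattern pair is typically wired into the
    # matcher (the actual register() helper lives earlier in this file and may
    # differ), using torch._inductor.pattern_matcher utilities:
    #
    #     from torch._inductor.pattern_matcher import fwd_only, register_replacement
    #
    #     def register(self, pm_pass: PatternMatcherPass) -> None:
    #         register_replacement(
    #             self.get_pattern(),      # subgraph to search for
    #             self.get_replacement(),  # fused subgraph to emit
    #             self.get_inputs(),       # example inputs for tracing
    #             fwd_only,                # trace the forward graph only
    #             pm_pass,
    #         )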
class QKNormRopeFusionPass(VllmInductorPass):
"""
    A pass that fuses the QKV split, Q/K RMSNorm, and rotary embedding
    operations into a single fused qkv_rmsnorm_rope operator.
"""
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
dtype = vllm_config.model_config.dtype
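        # The fused path is only registered for bfloat16 models; any other
        # dtype keeps the unfused graph.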
if dtype not in (torch.bfloat16,):
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
return
        # Use one attention layer to obtain metadata (such as head_dim) for the QKNormRopeFusion patterns.
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention)
if len(attn_layers) == 0:
logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.")
return
layer = next(iter(attn_layers.values()))
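        # Patterns are only registered when head_dim == 128; both the bias and
        # bias-free variants are registered for each candidate epsilon.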
        if layer.head_size != 128:
            logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal to 128", layer.head_size)
            return
        for epsilon in [1e-6, 1e-5]:
QKNormRopeFusionPattern(
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
).register(self.pattern_match_passes)
QKNormRopeFusionPatternWithBias(
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.matched_count = self.pattern_match_passes.apply(graph)
logger.debug("Fused %s QKNorm and Rope patterns", self.matched_count)
logger.debug("Patterns registered for replacement:")
pattern_idx = 0
for pattern_entry in self.pattern_match_passes.patterns.values():
for p in pattern_entry:
p_str = PatternPrettyPrinter.run(p.pattern)
logger.debug("Pattern %d: %s", pattern_idx, p_str)
pattern_idx += 1
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""
return True
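# A minimal usage sketch (assuming a vllm_config built by the engine and a
# post-grad FX graph containing the unfused subgraph):
#
#     fusion_pass = QKNormRopeFusionPass(vllm_config)
#     fusion_pass(graph_module.graph)  # rewrites matched subgraphs in place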