xc-llm-ascend/vllm_ascend/compilation/compiler_interface.py

Adopt inductor fusion and define quantization fusion pass (#4168)

### What this PR does / why we need it?
The main goal of this PR is to alleviate the high maintenance burden caused by model duplication when we optimize models. Some of our optimized models diverge only slightly from vLLM's modeling code, yet still require rewriting several parts of the original, which brings a non-negligible maintenance burden to vllm-ascend. To solve this, we propose to leverage `torch.compile` and the `inductor pattern matcher` to automatically fuse the patterns we want to merge. For more details, refer to the RFC: https://github.com/vllm-project/vllm-ascend/issues/4239

This PR integrates the `AddRMSNorm` and `Quant` operators, which can improve the inference speed of models using `w8a8` quantization.

### Does this PR introduce _any_ user-facing change?
Yes, it adds a new additional_config option.

### How was this patch tested?
```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The president of the United States is Mr.",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/vllm-ascend/Qwen3-8B-W8A8",
        # enforce_eager=True,
        tensor_parallel_size=1,
        trust_remote_code=True,
        gpu_memory_utilization=0.7,
        quantization="ascend",
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```
```text
Prompt: 'The president of the United States is Mr.', Generated text: ' Trump. The president of the United States is Mr. Biden. Which of the following statements is correct? \n\nA. Mr. Trump is Mr. Biden. \nB. Mr. Trump is not Mr. Biden. \nC. The president of the United States is not Mr. Trump. \nD. The president of the United States is not Mr. Biden.\n\nThe question presents a contradiction: it states that "The president of the United States is Mr. Trump" and "The president of'
```
- vLLM version: 86e178f7c4d8c3b0eaf3c8e3f810a83f63b90e24
- vLLM main: https://github.com/vllm-project/vllm/commit/86e178f7c4d8c3b0eaf3c8e3f810a83f63b90e24

---------
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
2025-12-04 10:29:48 +08:00
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
[Lint]Style: Convert `vllm-ascend/compilation` to ruff format (#5912)

### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.

### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```
**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def register_ascend_customop(vllm_config: VllmConfig | None)`, it evaluates the annotation expression `VllmConfig | None`. Since `VllmConfig` is assigned `None` at runtime, the expression effectively becomes `None | None`. In Python, `None` is an instance of `NoneType`. While the `|` operator is implemented for type objects (classes), it is not supported for `NoneType` instances, leading to the `TypeError` shown above.

**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of Annotations (PEP 563)**.

**3. Impact and Benefits:**
- By enabling `annotations`, Python no longer executes the `VllmConfig | None` operation during module load. Instead, it stores the annotation as a string literal, completely avoiding the `None | None` calculation.
- We can keep the `VllmConfig = None` placeholders. This ensures that other modules can still import these symbols without triggering an `ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the types correctly. This allows us to use modern syntax without sacrificing type safety or runtime stability.
- The only side effect is that `__annotations__` will now return strings instead of type objects. Since this module does not use runtime type enforcement or reflection, this change has no negative impact on existing functionality.

### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/11b6af5280d6d6dfb8953af16e67b25f819b3be9

---------
Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-16 20:57:46 +08:00
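The annotation failure described in the #5912 commit message can be reproduced in isolation. The sketch below is illustrative only and is not part of `compiler_interface.py`: it evaluates the `VllmConfig | None` expression directly against the `None` placeholder, then compiles the same signature from a source string that starts with `from __future__ import annotations`. The stub function body and the names `POSTPONED_SRC`, `namespace`, `eager_error`, and `annotation` are assumptions made for the demonstration.

```python
# Runtime placeholder, mirroring the circular-import pattern from the commit message.
VllmConfig = None

# 1) What eager annotation evaluation computes: `VllmConfig | None` is `None | None`.
try:
    VllmConfig | None
    eager_error = None
except TypeError as exc:
    eager_error = exc

# 2) The same signature with postponed evaluation (PEP 563). The source is kept in
#    a string so that the __future__ import is the first statement it contains.
POSTPONED_SRC = '''
from __future__ import annotations

VllmConfig = None  # runtime placeholder


def register_ascend_customop(vllm_config: VllmConfig | None = None):
    return vllm_config  # stub body, for illustration only
'''
namespace = {}
exec(POSTPONED_SRC, namespace)  # defining the function no longer evaluates the annotation
annotation = namespace["register_ascend_customop"].__annotations__["vllm_config"]

print(type(eager_error).__name__)  # the failure raised by the eager expression
print(repr(annotation))            # the annotation is stored as a plain string
```

Keeping the `|` expression out of module-load execution is what lets the `VllmConfig = None` placeholder and the modern union syntax coexist; code that reads `__annotations__` at runtime would see strings rather than type objects.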
from collections.abc import Callable
from typing import Any
[Feature] Support npuhraph_ex backend (#4700)

### What this PR does / why we need it?
We introduce the npugraph_ex backend through vLLM's adaptor dispatch mechanism to accelerate aclgraph. The solution is based on torch.compile and uses torchair to optimize the fx.graph. The performance gains come mainly from the static kernel. In our tests on Qwen3-30B, it achieved a performance improvement of over 5%.

### Does this PR introduce _any_ user-facing change?
Yes, we add a new switch named "enable_npugraph_ex" in additional_config; the default is False. We also add an example showing how to register a custom replacement pass.

### More information about this PR
This feature depends on the release of CANN and torch_npu in Q4. We tested it on a package that has not been publicly released yet and verified that the functionality works. The feature is still experimental at the moment; setting the config to true will directly raise an error. It is merged into the main branch as a set of preliminary commits to facilitate subsequent development and testing of the feature, and to avoid submitting an excessively large PR at once.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------
Signed-off-by: chencangtao <chencangtao@huawei.com>
Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: panchao-hub <315134829@qq.com>
Co-authored-by: wbigat <wbigat@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
2025-12-10 20:48:05 +08:00
import torch
import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
[Feature]refactor the npugraph_ex config, support online-infer with static kernel (#5775)

### What this PR does / why we need it?
This is a part of https://github.com/vllm-project/vllm-ascend/issues/4715#issue-3694310762
1. Refactor the npugraph_ex config and change the default configuration of the static kernel; the new default value is false.
2. Support online inference with the static kernel.
3. Fix the issue where manually modifying FX graphs caused an abnormal model return type, and remove the related redundant code.

### Does this PR introduce _any_ user-facing change?
Yes, the new config of npugraph_ex is as follows:
```
additional_config={
    "npugraph_ex_config": {
        "enable": True,
        "enable_static_kernel": False
    }
}
```
### How was this patch tested?
```
vllm serve /data/DeepSeek-V3.1-Terminus-w4a8 \
    --host 0.0.0.0 \
    --port 8004 \
    --data-parallel-size 4 \
    --tensor-parallel-size 4 \
    --quantization ascend \
    --seed 1024 \
    --served-model-name deepseek_v3 \
    --enable-expert-parallel \
    --max-num-seqs 48 \
    --max-model-len 40000 \
    --async-scheduling \
    --max-num-batched-tokens 9000 \
    --trust-remote-code \
    --no-enable-prefix-caching \
    --speculative-config '{"num_speculative_tokens": 3, "method":"deepseek_mtp","disable_padded_drafter_batch": false}' \
    --gpu-memory-utilization 0.9 \
    --compilation-config '{"cudagraph_capture_sizes":[4,32,64,112,160,176,192], "cudagraph_mode": "FULL_DECODE_ONLY"}' \
    --additional-config \
    '{"enable_shared_expert_dp": true,"multistream_overlap_shared_expert": true,"npugraph_ex_config":{"enable":true}}'
```
- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d

---------
Signed-off-by: chencangtao <chencangtao@huawei.com>
Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
2026-01-20 21:31:38 +08:00
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm_ascend.ascend_config import NpugraphExConfig, get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
    recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
    if not graph_returns_tuple(graph):
        return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx)
    return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
def fusion_pass_compile(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    compile_range: Range,
    key: str | None = None,
) -> tuple[Callable | None, Any | None]:
    def compile_inner(graph, example_inputs):
        current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
        graph = current_pass_manager(graph)
        return graph

    decompositions = select_decomp_table()

    compiled_fn = compile_fx(
        graph=graph,
        example_inputs=example_inputs,
        inner_compile=compile_inner,
        decompositions=decompositions,
    )

    return compiled_fn, None
def npugraph_ex_compile(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    vllm_config: VllmConfig,
    npugraph_ex_config: NpugraphExConfig,
    compile_range: Range,
    key: str | None = None,
) -> tuple[Callable | None, Any | None]:
    import torchair

    torch.npu.set_compile_mode(jit_compile=False)
    config = torchair.CompilerConfig()
    # use aclgraph mode, avoid the transformation from fx graph to Ascend IR.
    config.mode = "reduce-overhead"
    # execute FX graph in eager mode before graph mode to optimize FX graph.
    config.debug.run_eagerly = True
    if npugraph_ex_config.enable_static_kernel:
        config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
        # Restrict static-kernel compilation to the decode batch sizes listed in
        # cudagraph_capture_sizes. Otherwise every new shape would trigger a
        # static-kernel compilation at runtime and disrupt program execution.
        num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
        uniform_decode_query_len = num_spec_tokens + 1
        max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
        decode_cudagraph_batch_sizes = [
            x
            for x in vllm_config.compilation_config.cudagraph_capture_sizes
            if max_num_tokens >= x >= uniform_decode_query_len
        ]
        config.experimental_config.aclgraph._aclnn_static_shape_kernel_sym_value_range = decode_cudagraph_batch_sizes
    npugraph_ex = torchair.get_npu_backend(compiler_config=config)
    # torch.compile requires the output of the fx graph to be a tuple
    if not graph_returns_tuple(graph):
        return make_graph_return_tuple(graph, example_inputs, npugraph_ex), None
    return npugraph_ex(graph, example_inputs), None
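The shape filter in the static-kernel branch above can be sketched in isolation as follows. This is a minimal, dependency-free illustration and not part of the module: the helper name `select_static_kernel_sizes` and its plain-integer parameters are hypothetical stand-ins for the values read from `vllm_config`, and the numbers in the usage lines are illustrative only.

```python
def select_static_kernel_sizes(
    capture_sizes: list[int],
    max_num_seqs: int,
    num_speculative_tokens: int = 0,
) -> list[int]:
    """Return the capture sizes that fall inside the uniform-decode token range.

    Hypothetical helper mirroring the list comprehension in npugraph_ex_compile.
    """
    # Each decode request contributes one token plus any speculative tokens.
    uniform_decode_query_len = num_speculative_tokens + 1
    # Upper bound on the number of tokens in a decode-only batch.
    max_num_tokens = max_num_seqs * uniform_decode_query_len
    return [x for x in capture_sizes if uniform_decode_query_len <= x <= max_num_tokens]


capture_sizes = [4, 32, 64, 112, 160, 176, 192]
# With 3 speculative tokens the range is [4, 192], so every size is kept.
print(select_static_kernel_sizes(capture_sizes, max_num_seqs=48, num_speculative_tokens=3))
# Without speculative decoding the range shrinks to [1, 48].
print(select_static_kernel_sizes(capture_sizes, max_num_seqs=48))
```

Keeping the filter as a pure function over integers makes the bounds easy to check without an NPU or a full `VllmConfig`; whether the real module should be refactored this way is a separate design question.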
class AscendCompiler(CompilerInterface):
    """
    AscendCompiler is a custom compiler interface for the Ascend platform.
    This class provides a method to compile a PyTorch FX graph module with
    specific configurations for graph fusion and decomposition.
    """
[Lint]Style: Convert `vllm-ascend/compilation` to ruff format (#5912) ### What this PR does / why we need it? Convert `vllm-ascend/compilation` to ruff format. ### Does this PR introduce _any_ user-facing change? During this migration, we encountered some **errors** in our CI and testing environments, such as: ``` vllm_ascend/utils.py:653: in <module> def register_ascend_customop(vllm_config: VllmConfig | None = None): ^^^^^^^^^^^^^^^^^ E TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType' ``` **1. Root Cause Analysis:** The project uses a common pattern to break circular dependencies: ```python if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None # Placeholder assigned at runtime ``` When Python parses the function definition `def register_ascend_customop(vllm_config: VllmConfig | None)`, it attempts to evaluate the expression `VllmConfig | None`. Since `VllmConfig` is assigned `None` at runtime, the expression effectively becomes `None | None`. In Python, `None` is an instance of `NoneType`. While the `|` operator is implemented for Type objects (classes), it is not supported for `NoneType` instances, leading to the `TypeError` shown above. **2. Solution:** To maintain the modern `|` syntax required by our new linting standards while preserving our dependency management strategy, I have introduced: ```python from __future__ import annotations ``` at the top of the affected files. This enables **Postponed Evaluation of Annotations (PEP 563)**. **3. Impact and Benefits:** - By enabling `annotations`, Python no longer executes the `VllmConfig | None` operation during module load. Instead, it stores the annotation as a string literal, completely avoiding the `None | None` calculation. - We can keep the `VllmConfig = None` placeholders. This ensures that other modules can still import these symbols without triggering an `ImportError`, maintaining a stable dependency graph. 
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the types correctly. This allows us to use modern syntax without sacrificing type safety or runtime stability. - The only side effect is that `__annotations__` will now return strings instead of type objects. Since this module does not use runtime type enforcement or reflection, this change has zero negative impact on existing functionality. ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/11b6af5280d6d6dfb8953af16e67b25f819b3be9 --------- Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-16 20:57:46 +08:00
    name = "AscendCompiler"
[Feature]refactor the npugraph_ex config, support online-infer with static kernel (#5775) ### What this PR does / why we need it? This is a part of https://github.com/vllm-project/vllm-ascend/issues/4715#issue-3694310762 1. refactor the npugraph_ex config,modified the default configuration of the static kernel, new default value of static kernel is false 2. support online-infer with static kernel 3. fixed the issue where manually modifying FX graphs caused an abnormal model return type, and removed the related redundant code. ### Does this PR introduce _any_ user-facing change? yes,the new config of npugraph_ex is as follow: ``` additional_config={ "npugraph_ex_config": { "enable": True, "enable_static_kernel": False } } ``` ### How was this patch tested? ``` vllm serve /data/DeepSeek-V3.1-Terminus-w4a8 \ --host 0.0.0.0 \ --port 8004 \ --data-parallel-size 4 \ --tensor-parallel-size 4 \ --quantization ascend \ --seed 1024 \ --served-model-name deepseek_v3 \ --enable-expert-parallel \ --max-num-seqs 48 \ --max-model-len 40000 \ --async-scheduling \ --max-num-batched-tokens 9000 \ --trust-remote-code \ --no-enable-prefix-caching \ --speculative-config '{"num_speculative_tokens": 3, "method":"deepseek_mtp","disable_padded_drafter_batch": false}' \ --gpu-memory-utilization 0.9 \ --compilation-config '{"cudagraph_capture_sizes":[4,32,64,112,160,176,192], "cudagraph_mode": "FULL_DECODE_ONLY"}' \ --additional-config \ '{"enable_shared_expert_dp": true,"multistream_overlap_shared_expert": true,"npugraph_ex_config":{"enable":true}}' ``` - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d --------- Signed-off-by: chencangtao <chencangtao@huawei.com> Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com> Co-authored-by: chencangtao <chencangtao@huawei.com>
2026-01-20 21:31:38 +08:00
    def compute_hash(self, vllm_config: VllmConfig) -> str:
        npugraph_ex_config = get_ascend_config().npugraph_ex_config
        if npugraph_ex_config.enable:
            self.vllm_config = vllm_config
        return vllm_config.compute_hash()
    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable | None, Any | None]:
        npugraph_ex_config = get_ascend_config().npugraph_ex_config
        if npugraph_ex_config.enable:
            assert hasattr(self, "vllm_config")
            return npugraph_ex_compile(
                graph, example_inputs, compiler_config, self.vllm_config, npugraph_ex_config, compile_range, key
            )
[Feature] Support npuhraph_ex backend (#4700) ### What this PR does / why we need it? We introduced the npugraph_ex backend through the vllm's adaptor dispatch mechanism to accelerate aclgraph. This solution is based on torch.compile and uses torchair to optimize the fx.graph. The performance gains are mainly obtained from the static kernel. We conducted tests on Qwen3-30B and achieved over 5% performance optimization. ### Does this PR introduce _any_ user-facing change? Yes, we add a new switch named"enable_npugraph_ex" in additional_config, default is False. We also add an example to show how to register custom replacement pass ### More information about this PR This feature depends on the release of CANN and torch_npu in Q4. We tested it on a package that has not been publicly released yet and verified that the functionality works. This feature is still experimental at the moment; setting the config true will directly raise error. Merging into the main branch initially involves some preliminary commits to facilitate subsequent development and testing of the feature, as well as to avoid submitting an excessively large PR at once. - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: chencangtao <chencangtao@huawei.com> Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com> Co-authored-by: chencangtao <chencangtao@huawei.com> Co-authored-by: panchao-hub <315134829@qq.com> Co-authored-by: wbigat <wbigat@163.com> Co-authored-by: Mengqing Cao <cmq0113@163.com>
2025-12-10 20:48:05 +08:00
        else:
            return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)
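
The class above appears to follow a two-step pattern: `compute_hash` stashes the configuration when the `npugraph_ex` path is enabled, and `compile` later asserts that the stash exists before dispatching to one of two backends. The sketch below illustrates only that pattern with standard-library stand-ins. Every name in it (`ToyGraphExConfig`, `ToyCompiler`, `toy_graph_ex_compile`, `toy_fusion_compile`) is hypothetical and not part of vllm-ascend, and the hashing scheme is an assumption chosen for brevity.

```python
import hashlib
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyGraphExConfig:
    """Hypothetical stand-in for the npugraph_ex config object."""

    enable: bool = False


def toy_graph_ex_compile(graph: Callable, example_inputs: list[Any], config: dict[str, Any]):
    # Hypothetical stand-in for the graph-ex backend; it just tags the result.
    return graph, "graph_ex"


def toy_fusion_compile(graph: Callable, example_inputs: list[Any]):
    # Hypothetical stand-in for the fusion-pass backend.
    return graph, "fusion"


class ToyCompiler:
    name = "ToyCompiler"

    def __init__(self, graph_ex_config: ToyGraphExConfig) -> None:
        self.graph_ex_config = graph_ex_config

    def compute_hash(self, config: dict[str, Any]) -> str:
        # Keep a reference to the config only when the graph-ex path needs it.
        if self.graph_ex_config.enable:
            self.config = config
        payload = repr(sorted(config.items())).encode()
        return hashlib.sha256(payload).hexdigest()

    def compile(self, graph: Callable, example_inputs: list[Any]):
        if self.graph_ex_config.enable:
            # The graph-ex path depends on compute_hash() having run first.
            assert hasattr(self, "config"), "call compute_hash() before compile()"
            return toy_graph_ex_compile(graph, example_inputs, self.config)
        return toy_fusion_compile(graph, example_inputs)
```

In this sketch, calling `compile` on the enabled path without a prior `compute_hash` call fails the assertion, which mirrors the `hasattr(self, "vllm_config")` check in the real class.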