[Triton][Config] Add muls_add triton kernel and refactor AscendCompilationConfig (#5518)

### What this PR does / why we need it?
Add a muls_add Triton kernel together with its graph-fusion pass. In addition, this
PR refactors `AscendCompilationConfig` and deletes `NpugraphExConfig`, folding the former `npugraph_ex_config` options into `ascend_compilation_config`.
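For downstream users the rename is mechanical; a minimal before/after sketch taken from the examples updated in this PR:

```python
# Before (NpugraphExConfig, removed by this PR):
additional_config = {"npugraph_ex_config": {"enable": True, "enable_static_kernel": False}}

# After (folded into AscendCompilationConfig):
additional_config = {"ascend_compilation_config": {"enable_npugraph_ex": True, "enable_static_kernel": False}}
```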

### Does this PR introduce _any_ user-facing change?
None

### How was this patch tested?
CI passed with the newly added tests.


- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
whx
2026-03-02 17:54:25 +08:00
committed by GitHub
parent 8547520726
commit 16c879cdf7
14 changed files with 290 additions and 98 deletions

View File

@@ -31,7 +31,6 @@ The following table lists additional configuration options available in vLLM Ascend
 | `finegrained_tp_config` | dict | `{}` | Configuration options for module tensor parallelism |
 | `ascend_compilation_config` | dict | `{}` | Configuration options for ascend compilation |
 | `eplb_config` | dict | `{}` | Configuration options for ascend compilation |
-| `npugraph_ex_config` | dict | `{}` | Configuration options for Npugraph_ex backend |
 | `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
 | `dump_config_path` | str | `None` | Configuration file path for msprobe dump(eager mode). |
 | `enable_async_exponential` | bool | `False` | Whether to enable asynchronous exponential overlap. To enable asynchronous exponential, set this config to True. |
@@ -76,9 +75,12 @@ The details of each configuration option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
+| `enable_npugraph_ex` | bool | `True` | Whether to enable npugraph_ex backend. |
+| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
 | `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
 | `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
 | `fuse_allreduce_rms` | bool | `False` | Whether to enable fuse_allreduce_rms pass. It's set to False because of conflict with SP. |
+| `fuse_muls_add` | bool | `True` | Whether to enable fuse_muls_add pass. |

 **eplb_config**
@@ -91,16 +93,6 @@ The details of each configuration option are as follows:
 | `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
 | `num_redundant_experts` | int | `0` | Specify redundant experts during initialization. |
-**npugraph_ex_config**
-
-| Name | Type | Default | Description |
-|------------------------| ---- |---------|----------------------------------------------------------------------------------------|
-| `enable` | bool | `True` | Whether to enable npugraph_ex backend. |
-| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
-| `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
-| `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
-| `fuse_allreduce_rms` | bool | `False` | Whether to enable fuse_allreduce_rms pass. It's set to False because of conflict with SP. |

 ### Example
 An example of additional configuration is as follows:

View File

@@ -16,8 +16,8 @@ from vllm import LLM
 model = LLM(
     model="path/to/Qwen2-7B-Instruct",
     additional_config={
-        "npugraph_ex_config": {
-            "enable": True,
+        "ascend_compilation_config": {
+            "enable_npugraph_ex": True,
             "enable_static_kernel": False,
         }
     }
@@ -29,7 +29,7 @@ Online example:
 ```shell
 vllm serve Qwen/Qwen2-7B-Instruct
---additional-config '{"npugraph_ex_config":{"enable":true, "enable_static_kernel":false}}'
+--additional-config '{"ascend_compilation_config":{"enable_npugraph_ex":true, "enable_static_kernel":false}}'
 ```
 You can find more details about npugraph_ex [here](https://www.hiascend.com/document/detail/zh/Pytorch/730/modthirdparty/torchairuseguide/torchair_00021.html)
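To fall back to the plain fusion-pass compilation path, the same section can disable the backend entirely; a minimal sketch mirroring the tests updated in this PR:

```shell
vllm serve Qwen/Qwen2-7B-Instruct \
  --additional-config '{"ascend_compilation_config":{"enable_npugraph_ex":false}}'
```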

View File

@@ -28,7 +28,7 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
             "cudagraph_mode": "FULL_DECODE_ONLY",
             "pass_config": {"enable_sp": False}
         },
-        additional_config={"npugraph_ex_config": {"enable": False}}
+        additional_config={"ascend_compilation_config": {"enable_npugraph_ex": False}}
     ) as runner:
         no_sp_outputs = runner.model.generate(prompts, sampling_params)
@@ -41,7 +41,7 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
             "cudagraph_mode": "FULL_DECODE_ONLY",
             "pass_config": {"enable_sp": True}
         },
-        additional_config={"sp_threshold": 10, "npugraph_ex_config": {"enable": False}}
+        additional_config={"sp_threshold": 10, "ascend_compilation_config": {"enable_npugraph_ex": False}}
    ) as runner:
        sp_outputs = runner.model.generate(
            prompts, sampling_params)

View File

@@ -0,0 +1,34 @@
import pytest
import torch
from vllm_ascend.ops.triton.muls_add import muls_add_triton
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
@pytest.mark.parametrize(
("shape", "dtype", "scale"),
[
((1, 2048), torch.float16, 1.25),
((4000, 2048), torch.float16, 0.75),
((4, 2048), torch.bfloat16, 1.0),
],
)
@torch.inference_mode()
def test_muls_add_triton_correctness(shape, dtype, scale):
"""compare the correctness of muls_add_triton with the PyTorch baseline implementation."""
init_device_properties_triton()
device = "npu"
torch.manual_seed(0)
x = torch.randn(*shape, dtype=dtype, device=device)
y = torch.randn(*shape, dtype=dtype, device=device)
out_triton = muls_add_triton(x, y, scale)
out_ref = x * scale + y
rtol, atol = 1e-3, 1e-3
assert out_triton.shape == out_ref.shape
assert out_triton.dtype == out_ref.dtype
assert torch.allclose(out_triton, out_ref, rtol=rtol, atol=atol)
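For intuition, the op under test computes a fused elementwise `out = x * scale + y`; a tiny device-independent sanity sketch of the same semantics in plain PyTorch:

```python
import torch

x = torch.tensor([1.0, 2.0])
y = torch.tensor([10.0, 20.0])
scale = 0.5
# muls_add semantics: out[i] = x[i] * scale + y[i]
assert torch.equal(x * scale + y, torch.tensor([10.5, 21.0]))
```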

View File

@@ -153,7 +153,7 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8, 32, 64], "cudagraph_mode": "FULL_DECODE_ONLY"},
         "quantization": cur_case.quantization,
-        "additional_config": {"npugraph_ex_config": {"enable": False}},
+        "additional_config": {"ascend_compilation_config": {"enable_npugraph_ex": False}},
     }
     gen_and_valid(
         runner_kwargs=runner_kwargs,
@@ -171,7 +171,7 @@ def test_npugraph_ex_res_consistency(cur_case: LLMTestCase, monkeypatch):
         "quantization": cur_case.quantization,
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8, 32, 64], "cudagraph_mode": "FULL_DECODE_ONLY"},
-        "additional_config": {"npugraph_ex_config": {"enable": True}},
+        "additional_config": {"ascend_compilation_config": {"enable_npugraph_ex": True}},
     }
     gen_and_valid(
         runner_kwargs=runner_kwargs,
@@ -193,8 +193,8 @@ def test_npugraph_ex_with_static_kernel(cur_case: LLMTestCase, monkeypatch):
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8], "cudagraph_mode": "FULL_DECODE_ONLY"},
         "additional_config": {
-            "npugraph_ex_config": {
-                "enable": True,
+            "ascend_compilation_config": {
+                "enable_npugraph_ex": True,
                 "enable_static_kernel": True,
             }
         },

View File

@@ -66,13 +66,11 @@ class TestAscendConfig(TestBase):
         self.assertEqual(ascend_config.eplb_config.num_redundant_experts, 2)
         self.assertTrue(ascend_config.multistream_overlap_shared_expert)
-        npugraph_ex_config = ascend_config.npugraph_ex_config
-        self.assertTrue(npugraph_ex_config.enable)
-        self.assertFalse(npugraph_ex_config.enable_static_kernel)
         ascend_compilation_config = ascend_config.ascend_compilation_config
         self.assertFalse(ascend_compilation_config.fuse_norm_quant)
         self.assertFalse(ascend_config.enable_kv_nz)
+        self.assertTrue(ascend_compilation_config.enable_npugraph_ex)
+        self.assertFalse(ascend_compilation_config.enable_static_kernel)
         ascend_fusion_config = ascend_config.ascend_fusion_config
         self.assertFalse(ascend_fusion_config.fusion_ops_gmmswigluquant)
@@ -82,16 +80,16 @@ class TestAscendConfig(TestBase):
     def test_init_ascend_config_enable_npugraph_ex(self, mock_fix_incompatible_config):
         test_vllm_config = VllmConfig()
         test_vllm_config.additional_config = {
-            "npugraph_ex_config": {
-                "enable": True,
+            "ascend_compilation_config": {
+                "enable_npugraph_ex": True,
                 "enable_static_kernel": True
             },
             "refresh": True
         }
-        npugraph_ex_config = init_ascend_config(
-            test_vllm_config).npugraph_ex_config
-        self.assertTrue(npugraph_ex_config.enable)
-        self.assertTrue(npugraph_ex_config.enable_static_kernel)
+        ascend_compilation_config = init_ascend_config(
+            test_vllm_config).ascend_compilation_config
+        self.assertTrue(ascend_compilation_config.enable_npugraph_ex)
+        self.assertTrue(ascend_compilation_config.enable_static_kernel)

     @_clean_up_ascend_config
     @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config")

View File

@@ -118,8 +118,6 @@ class AscendConfig:
         from vllm_ascend.utils import get_flashcomm2_config_and_validate
         self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config)

-        npugraph_ex_config = additional_config.get("npugraph_ex_config", {})
-        self.npugraph_ex_config = NpugraphExConfig(**npugraph_ex_config)
-
         # We find that _npu_paged_attention still performs better than
         # npu_fused_infer_attention_score in some cases. We allow to execute
         # _npu_paged_attention in this cases. This should be removed once
@@ -163,8 +161,8 @@ class AscendConfig:
     def update_compile_ranges_split_points(self):
         vllm_config = self.vllm_config
-        if self.npugraph_ex_config.enable:
-            if self.npugraph_ex_config.fuse_allreduce_rms:
+        if self.ascend_compilation_config.enable_npugraph_ex:
+            if self.ascend_compilation_config.fuse_allreduce_rms:
                 from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
                 new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
@@ -253,56 +251,9 @@ class AscendCompilationConfig:
     deployed on Ascend platforms.
     """

-    def __init__(
-        self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = True, fuse_allreduce_rms: bool = False, **kwargs
-    ):
-        """
-        Initialize the configuration.
-
-        Args:
-            fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
-                When set to True, the system will optimize norm and quant operations.
-                Default: True
-            fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
-                Default: True
-            fuse_allreduce_rms (bool): Whether to enable allreduce and addrmsnorm fusion optimization.
-                Default: False
-            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
-        """
-        self.fuse_norm_quant = fuse_norm_quant
-        self.fuse_qknorm_rope = fuse_qknorm_rope
-        self.fuse_allreduce_rms = fuse_allreduce_rms
-
-
-class AscendFusionConfig:
-    """
-    Configuration for controlling whether to use a fused operator gmmswigluquant.
-    """
-
-    def __init__(self, fusion_ops_gmmswigluquant: bool = True, **kwargs):
-        """
-        Initialize the configuration.
-
-        Args:
-            fusion_ops_gmmswigluquant (bool): Whether to use a fused operator gmmswigluquant.
-                When set to True, the system will use a fused operator gmmswigluquant.
-                Default: True
-            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
-        """
-        self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant
-
-
-class NpugraphExConfig:
-    """
-    Configuration for controlling the behavior of npugraph_ex backend.
-    This class provides a way to configure whether to use the npugraph_ex backend and static kernel.
-    These configurations can directly impact the performance and behavior of models deployed on Ascend platforms.
-    """
-
     def __init__(
         self,
-        enable: bool = True,
+        enable_npugraph_ex: bool = True,
         enable_static_kernel: bool = False,
         fuse_norm_quant: bool = True,
         fuse_qknorm_rope: bool = True,
@@ -313,7 +264,7 @@ class NpugraphExConfig:
        Initialize the configuration.

        Args:
-            enable (bool): Whether to enable npugraph_ex backend.
+            enable_npugraph_ex (bool): Whether to enable npugraph_ex backend.
                When set to True, the Fx graph generated by Dymano will be
                optimized and compiled by the npugraph_ex backend.
                Default: True
@@ -333,11 +284,32 @@ class NpugraphExConfig:
                Default: False
            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
        """
-        self.enable = enable
-        self.enable_static_kernel = enable_static_kernel
         self.fuse_norm_quant = fuse_norm_quant
         self.fuse_qknorm_rope = fuse_qknorm_rope
         self.fuse_allreduce_rms = fuse_allreduce_rms
+        self.enable_npugraph_ex = enable_npugraph_ex
+        self.enable_static_kernel = enable_static_kernel
+        self.fuse_muls_add = kwargs.get("fuse_muls_add", True)
+        if self.enable_static_kernel:
+            assert self.enable_npugraph_ex, "Static kernel generation requires npugraph_ex to be enabled."
+
+
+class AscendFusionConfig:
+    """
+    Configuration for controlling whether to use a fused operator gmmswigluquant.
+    """
+
+    def __init__(self, fusion_ops_gmmswigluquant: bool = True, **kwargs):
+        """
+        Initialize the configuration.
+
+        Args:
+            fusion_ops_gmmswigluquant (bool): Whether to use a fused operator gmmswigluquant.
+                When set to True, the system will use a fused operator gmmswigluquant.
+                Default: True
+            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
+        """
+        self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant


 class XliteGraphConfig:
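With the refactor, all compilation knobs live on one object and the static-kernel invariant is enforced at construction time; a minimal sketch of the new behavior, based on the constructor shown above:

```python
from vllm_ascend.ascend_config import AscendCompilationConfig

cfg = AscendCompilationConfig()
assert cfg.enable_npugraph_ex          # backend on by default
assert not cfg.enable_static_kernel    # static kernel off by default
assert cfg.fuse_muls_add               # new flag, read from **kwargs, default True

# Static kernels require the npugraph_ex backend, so this now fails fast:
try:
    AscendCompilationConfig(enable_npugraph_ex=False, enable_static_kernel=True)
except AssertionError as e:
    print(e)  # Static kernel generation requires npugraph_ex to be enabled.
```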

View File

@@ -30,7 +30,7 @@ from vllm.compilation.compiler_interface import CompilerInterface
 from vllm.config import VllmConfig
 from vllm.config.utils import Range
-from vllm_ascend.ascend_config import NpugraphExConfig, get_ascend_config
+from vllm_ascend.ascend_config import AscendCompilationConfig, get_ascend_config
 from vllm_ascend.utils import COMPILATION_PASS_KEY
@@ -71,7 +71,7 @@ def npugraph_ex_compile(
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    vllm_config: VllmConfig,
-    npugraph_ex_config: NpugraphExConfig,
+    ascend_compilation_config: AscendCompilationConfig,
    compile_range: Range,
    key: str | None = None,
 ) -> tuple[Callable | None, Any | None]:
@@ -83,7 +83,7 @@ def npugraph_ex_compile(
    config.mode = "reduce-overhead"
    # execute FX graph in eager mode before graph mode to optimize FX graph.
    config.debug.run_eagerly = True
-    if npugraph_ex_config.enable_static_kernel:
+    if ascend_compilation_config.enable_static_kernel:
        config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
        # According to the cudagraph_capture_size configuration, set the shapes
        # that can trigger the compilation of static kernel. If this configuration is
@@ -117,8 +117,8 @@ class AscendCompiler(CompilerInterface):
    name = "AscendCompiler"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
-        npugraph_ex_config = get_ascend_config().npugraph_ex_config
-        if npugraph_ex_config.enable:
+        npugraph_ex_enabled = get_ascend_config().ascend_compilation_config.enable_npugraph_ex
+        if npugraph_ex_enabled:
            self.vllm_config = vllm_config
            return vllm_config.compute_hash()
@@ -134,11 +134,11 @@ class AscendCompiler(CompilerInterface):
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

-        npugraph_ex_config = get_ascend_config().npugraph_ex_config
-        if npugraph_ex_config.enable:
+        ascend_compilation_config = get_ascend_config().ascend_compilation_config
+        if ascend_compilation_config.enable_npugraph_ex:
            assert hasattr(self, "vllm_config")
            return npugraph_ex_compile(
-                graph, example_inputs, compiler_config, self.vllm_config, npugraph_ex_config, compile_range, key
+                graph, example_inputs, compiler_config, self.vllm_config, ascend_compilation_config, compile_range, key
            )
        else:
            return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)

View File

@@ -64,6 +64,11 @@ class GraphFusionPassManager:
            self.passes.append(MatmulAllReduceAddRMSNormPass(config))

+        if self.ascend_compilation_config.get("fuse_muls_add", True):
+            from .passes.muls_add_pass import MulsAddFusionPass
+
+            self.passes.append(MulsAddFusionPass(config))
+
        if config.compilation_config.pass_config.enable_sp:
            from .passes.sequence_parallelism import AscendSequenceParallelismPass

View File

@@ -0,0 +1,117 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import torch
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger
from vllm_ascend.compilation.passes.base_pattern import BasePattern
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.15.0"):
from vllm.compilation.vllm_inductor_pass import VllmInductorPass # type: ignore
else:
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
class MulsAddPattern(BasePattern):
"""
Pattern that matches an element-wise mul + add sequence:
tmp = x * scale
out = tmp + y
and replaces it with a call to the muls_add_triton kernel.
"""
def __init__(self, vllm_config: VllmConfig, scale: float = 1.0):
super().__init__(vllm_config)
self.scale = scale
def get_inputs(self) -> list[torch.Tensor]:
"""
Generate example inputs for the MulsAddPattern.
The exact shapes are not important for pattern matching; they only
provide meta information for the pattern matcher.
"""
x = torch.randn(2, 2048, device="npu", dtype=self.dtype)
y = torch.randn(2, 2048, device="npu", dtype=self.dtype)
# Only tensor inputs are needed here. The scalar scale is stored on the
# pattern instance (self.scale) instead of being passed as an input.
return [x, y]
def get_pattern(self):
def pattern(x: torch.Tensor, y: torch.Tensor):
"""
Pattern for element-wise x * scale + y.
"""
tmp = x * self.scale
out = tmp + y
return out
return pattern
def get_replacement(self):
def replacement(x: torch.Tensor, y: torch.Tensor):
"""
Replacement that calls the muls_add_triton kernel using the
class-level scalar self.scale.
"""
return torch.ops.vllm.muls_add(x, y, self.scale)
return replacement
class MulsAddFusionPass(VllmInductorPass):
"""
A fusion pass that replaces simple element-wise x * scale + y patterns
with the Triton-based muls_add_triton kernel on Ascend.
"""
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="muls_add_fusion_pass")
# For now we enable this pass for all floating-point dtypes that the
# model is configured to use.
dtype = vllm_config.model_config.dtype
if dtype not in (torch.float16, torch.bfloat16, torch.float32):
logger.debug("MulsAdd fusion not enabled: unsupported dtype %s", dtype)
return
# Currently we only register a single pattern instance with a fixed
# scalar scale value. If needed, multiple instances with different
# scales can be added here in the future.
MulsAddPattern(vllm_config, scale=1.0).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph) -> None: # type: ignore[override]
self.begin()
self.matched_count = self.pattern_match_passes.apply(graph)
logger.debug("Fused %s muls_add patterns", self.matched_count)
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
For now, muls_add fusion is always allowed for the selected ranges.
This hook exists so that we can add more fine-grained range control
in the future if needed.
"""
return True
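Conceptually, the pass turns two elementwise ops into one fused kernel call; an illustrative before/after sketch (the real matching happens on the traced FX graph, not on source code, and the scale comes from the registered pattern instance):

```python
# Before fusion: two elementwise kernel launches
tmp = x * scale
out = tmp + y

# After fusion: a single launch of the registered custom op,
# which dispatches to the muls_add Triton kernel on NPU
out = torch.ops.vllm.muls_add(x, y, scale)
```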

View File

@@ -15,6 +15,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.rotary_embedding import rope_forward_oot
+from vllm_ascend.ops.triton.muls_add import muls_add_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import npu_stream_switch, prefetch_stream
@@ -201,6 +202,14 @@ def _rope_forward_oot_impl_fake(
    return query, key

+
+def _muls_add_impl_fake(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    scale: float,
+) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
 direct_register_custom_op(
    op_name="maybe_chunk_residual",
    op_func=_maybe_chunk_residual_impl,
@@ -272,3 +281,11 @@ direct_register_custom_op(
    mutates_args=[],
    dispatch_key="PrivateUse1",
 )
+
+direct_register_custom_op(
+    op_name="muls_add",
+    op_func=muls_add_triton,
+    fake_impl=_muls_add_impl_fake,
+    mutates_args=[],
+    dispatch_key="PrivateUse1",
+)
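Once registered, the op is reachable through the `torch.ops` namespace (with the fake impl giving Dynamo/meta tracing the right output shape); a minimal usage sketch, assuming an NPU device with Triton available:

```python
import torch

x = torch.randn(4, 2048, dtype=torch.float16, device="npu")
y = torch.randn_like(x)
out = torch.ops.vllm.muls_add(x, y, 0.5)  # dispatches to muls_add_triton
assert torch.allclose(out, x * 0.5 + y, rtol=1e-3, atol=1e-3)
```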

View File

@@ -0,0 +1,57 @@
import torch
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit
def muls_add_kernel(
x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
scale, # Scale factor.
n_elements, # Size of the vector.
n_blocks, # Total number of blocks.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
):
pid = tl.program_id(axis=0)
num_programs = tl.num_programs(axis=0)
for block_id in range(pid, n_blocks, num_programs):
block_start = block_id * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x * scale + y
tl.store(output_ptr + offsets, output, mask=mask)
def muls_add_triton(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
assert x.shape == y.shape, "Input tensors must have the same shape."
hidden_size = x.shape[-1]
n_elements = x.numel()
output = torch.empty_like(x)
# Determine the number of vector cores available
num_cores = get_vectorcore_num()
# Define block size
BLOCK_SIZE = max(hidden_size // 2, 1024)
# Calculate the number of programs to launch
num_blocks = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
num_programs = min(num_blocks, num_cores)
# Launch the Triton kernel
muls_add_kernel[(num_programs,)](
x,
y,
output,
scale,
n_elements,
num_blocks,
BLOCK_SIZE=BLOCK_SIZE,
)
return output
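To make the launch arithmetic concrete, here is the grid math worked through for one shape from the new unit test; the core count is an illustrative assumption, since `get_vectorcore_num()` is hardware-dependent:

```python
# For x of shape (4000, 2048):
hidden_size = 2048
n_elements = 4000 * 2048                    # 8_192_000 elements
BLOCK_SIZE = max(hidden_size // 2, 1024)    # max(1024, 1024) = 1024
num_blocks = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE   # 8000 blocks
num_cores = 48                              # hypothetical vector-core count
num_programs = min(num_blocks, num_cores)   # 48 persistent programs
# Each program grid-strides over blocks pid, pid + 48, pid + 96, ...
```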

View File

@@ -279,7 +279,7 @@ class NPUPlatform(Platform):
        if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            compilation_config.mode = CompilationMode.NONE
-            ascend_config.npugraph_ex_config.enable = False
+            ascend_config.ascend_compilation_config.enable_npugraph_ex = False
        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
            logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode")
            assert compilation_config.mode == CompilationMode.VLLM_COMPILE, (
@@ -299,7 +299,7 @@ class NPUPlatform(Platform):
            # not be detected in advance assert.
            compilation_config.splitting_ops.extend(["vllm::mla_forward"])
            update_aclgraph_sizes(vllm_config)
-            ascend_config.npugraph_ex_config.enable = False
+            ascend_config.ascend_compilation_config.enable_npugraph_ex = False
        elif (
            compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
            or compilation_config.cudagraph_mode == CUDAGraphMode.FULL
@@ -328,7 +328,7 @@ class NPUPlatform(Platform):
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            compilation_config.mode = CompilationMode.NONE
-            ascend_config.npugraph_ex_config.enable = False
+            ascend_config.ascend_compilation_config.enable_npugraph_ex = False

        # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
        # Then, we will have to discuss the error handling strategy and user experience

View File

@@ -138,8 +138,8 @@ class NPUWorker(WorkerBase):
        self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER
-        npugraph_ex_config = get_ascend_config().npugraph_ex_config
-        if npugraph_ex_config.enable and npugraph_ex_config.enable_static_kernel:
+        ascend_compilation_config = get_ascend_config().ascend_compilation_config
+        if ascend_compilation_config.enable_npugraph_ex and ascend_compilation_config.enable_static_kernel:
            # Prevent duplicate triggers, execute the exit logic only once
            shutdown_request = False