diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index 6396eb05..f655c49f 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -31,7 +31,6 @@ The following table lists additional configuration options available in vLLM Asc
 | `finegrained_tp_config` | dict | `{}` | Configuration options for module tensor parallelism |
 | `ascend_compilation_config` | dict | `{}` | Configuration options for ascend compilation |
 | `eplb_config` | dict | `{}` | Configuration options for ascend compilation |
-| `npugraph_ex_config` | dict | `{}` | Configuration options for Npugraph_ex backend |
 | `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
 | `dump_config_path` | str | `None` | Configuration file path for msprobe dump(eager mode). |
 | `enable_async_exponential` | bool | `False` | Whether to enable asynchronous exponential overlap. To enable asynchronous exponential, set this config to True. |
@@ -76,9 +75,12 @@ The details of each configuration option are as follows:
 
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
+| `enable_npugraph_ex` | bool | `True` | Whether to enable npugraph_ex backend. |
+| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
 | `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
 | `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
 | `fuse_allreduce_rms` | bool | `False` | Whether to enable fuse_allreduce_rms pass. It's set to False because of conflict with SP. |
+| `fuse_muls_add` | bool | `True` | Whether to enable fuse_muls_add pass. |
 
 **eplb_config**
 
@@ -91,16 +93,6 @@ The details of each configuration option are as follows:
 | `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory.|
 | `num_redundant_experts` | int | `0` | Specify redundant experts during initialization. |
 
-**npugraph_ex_config**
-
-| Name | Type | Default | Description |
-|------------------------| ---- |---------|----------------------------------------------------------------------------------------|
-| `enable` | bool | `True` | Whether to enable npugraph_ex backend. |
-| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
-| `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
-| `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
-| `fuse_allreduce_rms` | bool | `False` | Whether to enable fuse_allreduce_rms pass. It's set to False because of conflict with SP. |
-
 ### Example
 
 An example of additional configuration is as follows:
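For readers migrating from the removed `npugraph_ex_config` block, here is a minimal sketch of the consolidated layout (the model path is a placeholder, and the values shown are the defaults documented in the table above):

```python
from vllm import LLM

# All npugraph_ex and fusion-pass switches now live under a single
# "ascend_compilation_config" dict in additional_config.
model = LLM(
    model="path/to/Qwen2-7B-Instruct",
    additional_config={
        "ascend_compilation_config": {
            "enable_npugraph_ex": True,
            "enable_static_kernel": False,
            "fuse_norm_quant": True,
            "fuse_qknorm_rope": True,     # set to False if Triton is unavailable
            "fuse_allreduce_rms": False,  # conflicts with SP when enabled
            "fuse_muls_add": True,
        }
    },
)
```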
diff --git a/docs/source/user_guide/feature_guide/npugraph_ex.md b/docs/source/user_guide/feature_guide/npugraph_ex.md
index fa318e9a..dc05284e 100644
--- a/docs/source/user_guide/feature_guide/npugraph_ex.md
+++ b/docs/source/user_guide/feature_guide/npugraph_ex.md
@@ -16,8 +16,8 @@ from vllm import LLM
 model = LLM(
     model="path/to/Qwen2-7B-Instruct",
     additional_config={
-        "npugraph_ex_config": {
-            "enable": True,
+        "ascend_compilation_config": {
+            "enable_npugraph_ex": True,
             "enable_static_kernel": False,
         }
     }
@@ -29,7 +29,7 @@ Online example:
 
 ```shell
 vllm serve Qwen/Qwen2-7B-Instruct
---additional-config '{"npugraph_ex_config":{"enable":true, "enable_static_kernel":false}}'
+--additional-config '{"ascend_compilation_config":{"enable_npugraph_ex":true, "enable_static_kernel":false}}'
 ```
 
 You can find more details about npugraph_ex [here](https://www.hiascend.com/document/detail/zh/Pytorch/730/modthirdparty/torchairuseguide/torchair_00021.html)
diff --git a/tests/e2e/multicard/2-cards/test_sp_pass.py b/tests/e2e/multicard/2-cards/test_sp_pass.py
index f5ac722f..c3e99dfc 100644
--- a/tests/e2e/multicard/2-cards/test_sp_pass.py
+++ b/tests/e2e/multicard/2-cards/test_sp_pass.py
@@ -28,7 +28,7 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
             "cudagraph_mode": "FULL_DECODE_ONLY",
             "pass_config": {"enable_sp": False}
         },
-        additional_config={"npugraph_ex_config": {"enable": False}}
+        additional_config={"ascend_compilation_config": {"enable_npugraph_ex": False}}
     ) as runner:
         no_sp_outputs = runner.model.generate(prompts, sampling_params)
 
@@ -41,7 +41,7 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
             "cudagraph_mode": "FULL_DECODE_ONLY",
             "pass_config": {"enable_sp": True}
         },
-        additional_config={"sp_threshold": 10, "npugraph_ex_config": {"enable": False}}
+        additional_config={"sp_threshold": 10, "ascend_compilation_config": {"enable_npugraph_ex": False}}
     ) as runner:
         sp_outputs = runner.model.generate(
             prompts, sampling_params)
diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_muls_add.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_muls_add.py
new file mode 100644
index 00000000..ac1af2c3
--- /dev/null
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_muls_add.py
@@ -0,0 +1,34 @@
+import pytest
+import torch
+
+from vllm_ascend.ops.triton.muls_add import muls_add_triton
+from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+
+
+@pytest.mark.parametrize(
+    ("shape", "dtype", "scale"),
+    [
+        ((1, 2048), torch.float16, 1.25),
+        ((4000, 2048), torch.float16, 0.75),
+        ((4, 2048), torch.bfloat16, 1.0),
+    ],
+)
+@torch.inference_mode()
+def test_muls_add_triton_correctness(shape, dtype, scale):
+    """Compare the correctness of muls_add_triton with the PyTorch baseline implementation."""
+    init_device_properties_triton()
+    device = "npu"
+
+    torch.manual_seed(0)
+    x = torch.randn(*shape, dtype=dtype, device=device)
+    y = torch.randn(*shape, dtype=dtype, device=device)
+
+    out_triton = muls_add_triton(x, y, scale)
+    out_ref = x * scale + y
+
+    rtol, atol = 1e-3, 1e-3
+
+    assert out_triton.shape == out_ref.shape
+    assert out_triton.dtype == out_ref.dtype
+    assert torch.allclose(out_triton, out_ref, rtol=rtol, atol=atol)
+
diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py
index e031e93f..6835b194 100644
--- a/tests/e2e/singlecard/test_aclgraph_accuracy.py
+++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py
@@ -153,7 +153,7 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8, 32, 64], "cudagraph_mode": "FULL_DECODE_ONLY"},
         "quantization": cur_case.quantization,
-        "additional_config": {"npugraph_ex_config": {"enable": False}},
+        "additional_config": {"ascend_compilation_config": {"enable_npugraph_ex": False}},
     }
     gen_and_valid(
         runner_kwargs=runner_kwargs,
@@ -171,7 +171,7 @@ def test_npugraph_ex_res_consistency(cur_case: LLMTestCase, monkeypatch):
         "quantization": cur_case.quantization,
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8, 32, 64], "cudagraph_mode": "FULL_DECODE_ONLY"},
-        "additional_config": {"npugraph_ex_config": {"enable": True}},
+        "additional_config": {"ascend_compilation_config": {"enable_npugraph_ex": True}},
     }
     gen_and_valid(
         runner_kwargs=runner_kwargs,
@@ -193,8 +193,8 @@ def test_npugraph_ex_with_static_kernel(cur_case: LLMTestCase, monkeypatch):
         "max_model_len": 1024,
         "compilation_config": {"cudagraph_capture_sizes": [4, 8], "cudagraph_mode": "FULL_DECODE_ONLY"},
         "additional_config": {
-            "npugraph_ex_config": {
-                "enable": True,
+            "ascend_compilation_config": {
+                "enable_npugraph_ex": True,
                 "enable_static_kernel": True,
             }
         },
diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py
index c3c42d4d..c510090f 100644
--- a/tests/ut/test_ascend_config.py
+++ b/tests/ut/test_ascend_config.py
@@ -66,13 +66,11 @@ class TestAscendConfig(TestBase):
         self.assertEqual(ascend_config.eplb_config.num_redundant_experts, 2)
         self.assertTrue(ascend_config.multistream_overlap_shared_expert)
 
-        npugraph_ex_config = ascend_config.npugraph_ex_config
-        self.assertTrue(npugraph_ex_config.enable)
-        self.assertFalse(npugraph_ex_config.enable_static_kernel)
-
         ascend_compilation_config = ascend_config.ascend_compilation_config
         self.assertFalse(ascend_compilation_config.fuse_norm_quant)
         self.assertFalse(ascend_config.enable_kv_nz)
+        self.assertTrue(ascend_compilation_config.enable_npugraph_ex)
+        self.assertFalse(ascend_compilation_config.enable_static_kernel)
 
         ascend_fusion_config = ascend_config.ascend_fusion_config
         self.assertFalse(ascend_fusion_config.fusion_ops_gmmswigluquant)
@@ -82,16 +80,16 @@ class TestAscendConfig(TestBase):
     def test_init_ascend_config_enable_npugraph_ex(self, mock_fix_incompatible_config):
         test_vllm_config = VllmConfig()
         test_vllm_config.additional_config = {
-            "npugraph_ex_config": {
-                "enable": True,
+            "ascend_compilation_config": {
+                "enable_npugraph_ex": True,
                 "enable_static_kernel": True
             },
             "refresh": True
         }
-        npugraph_ex_config = init_ascend_config(
-            test_vllm_config).npugraph_ex_config
-        self.assertTrue(npugraph_ex_config.enable)
-        self.assertTrue(npugraph_ex_config.enable_static_kernel)
+        ascend_compilation_config = init_ascend_config(
+            test_vllm_config).ascend_compilation_config
+        self.assertTrue(ascend_compilation_config.enable_npugraph_ex)
+        self.assertTrue(ascend_compilation_config.enable_static_kernel)
 
     @_clean_up_ascend_config
     @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config")
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 34e72d68..e524ad76 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -118,8 +118,6 @@ class AscendConfig:
         from vllm_ascend.utils import get_flashcomm2_config_and_validate
         self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self,
                                                                                         vllm_config)
-        npugraph_ex_config = additional_config.get("npugraph_ex_config", {})
-        self.npugraph_ex_config = NpugraphExConfig(**npugraph_ex_config)
         # We find that _npu_paged_attention still performs better than
         # npu_fused_infer_attention_score in some cases. We allow to execute
         # _npu_paged_attention in this cases. This should be removed once
@@ -163,8 +161,8 @@ class AscendConfig:
 
     def update_compile_ranges_split_points(self):
         vllm_config = self.vllm_config
-        if self.npugraph_ex_config.enable:
-            if self.npugraph_ex_config.fuse_allreduce_rms:
+        if self.ascend_compilation_config.enable_npugraph_ex:
+            if self.ascend_compilation_config.fuse_allreduce_rms:
                 from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
 
                 new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
@@ -253,56 +251,9 @@ class AscendCompilationConfig:
     deployed on Ascend platforms.
     """
 
-    def __init__(
-        self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = True, fuse_allreduce_rms: bool = False, **kwargs
-    ):
-        """
-        Initialize the configuration.
-
-        Args:
-            fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
-                When set to True, the system will optimize norm and quant operations.
-                Default: True
-            fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
-                Default: True
-            fuse_allreduce_rms (bool): Whether to enable allreduce and addrmsnorm fusion optimization.
-                Default: False
-            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
-        """
-        self.fuse_norm_quant = fuse_norm_quant
-        self.fuse_qknorm_rope = fuse_qknorm_rope
-        self.fuse_allreduce_rms = fuse_allreduce_rms
-
-
-class AscendFusionConfig:
-    """
-    Configuration for controlling whether to use a fused operator gmmswigluquant.
-    """
-
-    def __init__(self, fusion_ops_gmmswigluquant: bool = True, **kwargs):
-        """
-        Initialize the configuration.
-
-        Args:
-            fusion_ops_gmmswigluquant (bool): Whether to use a fused operator gmmswigluquant.
-                When set to True, the system will use a fused operator gmmswigluquant.
-                Default: True
-            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
-        """
-        self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant
-
-
-class NpugraphExConfig:
-    """
-    Configuration for controlling the behavior of npugraph_ex backend.
-
-    This class provides a way to configure whether to use the npugraph_ex backend and static kernel.
-    These configurations can directly impact the performance and behavior of models deployed on Ascend platforms.
-    """
-
     def __init__(
         self,
-        enable: bool = True,
+        enable_npugraph_ex: bool = True,
         enable_static_kernel: bool = False,
         fuse_norm_quant: bool = True,
         fuse_qknorm_rope: bool = True,
@@ -313,7 +264,7 @@ class NpugraphExConfig:
         Initialize the configuration.
 
         Args:
-            enable (bool): Whether to enable npugraph_ex backend.
+            enable_npugraph_ex (bool): Whether to enable npugraph_ex backend.
                 When set to True, the Fx graph generated by Dymano will be optimized
                 and compiled by the npugraph_ex backend.
                 Default: True
""" - self.enable = enable - self.enable_static_kernel = enable_static_kernel self.fuse_norm_quant = fuse_norm_quant self.fuse_qknorm_rope = fuse_qknorm_rope self.fuse_allreduce_rms = fuse_allreduce_rms + self.enable_npugraph_ex = enable_npugraph_ex + self.enable_static_kernel = enable_static_kernel + self.fuse_muls_add = kwargs.get("fuse_muls_add", True) + if self.enable_static_kernel: + assert self.enable_npugraph_ex, "Static kernel generation requires npugraph_ex to be enabled." + + +class AscendFusionConfig: + """ + Configuration for controlling whether to use a fused operator gmmswigluquant. + """ + + def __init__(self, fusion_ops_gmmswigluquant: bool = True, **kwargs): + """ + Initialize the configuration. + + Args: + fusion_ops_gmmswigluquant (bool): Whether to use a fused operator gmmswigluquant. + When set to True, the system will use a fused operator gmmswigluquant. + Default: True + **kwargs: Additional optional parameters for forward compatibility and configuration extension. + """ + self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant class XliteGraphConfig: diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 2c67a185..22b6f8a1 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -30,7 +30,7 @@ from vllm.compilation.compiler_interface import CompilerInterface from vllm.config import VllmConfig from vllm.config.utils import Range -from vllm_ascend.ascend_config import NpugraphExConfig, get_ascend_config +from vllm_ascend.ascend_config import AscendCompilationConfig, get_ascend_config from vllm_ascend.utils import COMPILATION_PASS_KEY @@ -71,7 +71,7 @@ def npugraph_ex_compile( example_inputs: list[Any], compiler_config: dict[str, Any], vllm_config: VllmConfig, - npugraph_ex_config: NpugraphExConfig, + ascend_compilation_config: AscendCompilationConfig, compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: @@ -83,7 +83,7 @@ def npugraph_ex_compile( config.mode = "reduce-overhead" # execute FX graph in eager mode before graph mode to optimize FX graph. config.debug.run_eagerly = True - if npugraph_ex_config.enable_static_kernel: + if ascend_compilation_config.enable_static_kernel: config.experimental_config.aclgraph._aclnn_static_shape_kernel = True # According to the cudagraph_capture_size configuration, set the shapes # that can trigger the compilation of static kernel. 
diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py
index 2c67a185..22b6f8a1 100644
--- a/vllm_ascend/compilation/compiler_interface.py
+++ b/vllm_ascend/compilation/compiler_interface.py
@@ -30,7 +30,7 @@ from vllm.compilation.compiler_interface import CompilerInterface
 from vllm.config import VllmConfig
 from vllm.config.utils import Range
 
-from vllm_ascend.ascend_config import NpugraphExConfig, get_ascend_config
+from vllm_ascend.ascend_config import AscendCompilationConfig, get_ascend_config
 from vllm_ascend.utils import COMPILATION_PASS_KEY
 
 
@@ -71,7 +71,7 @@ def npugraph_ex_compile(
     example_inputs: list[Any],
     compiler_config: dict[str, Any],
     vllm_config: VllmConfig,
-    npugraph_ex_config: NpugraphExConfig,
+    ascend_compilation_config: AscendCompilationConfig,
     compile_range: Range,
     key: str | None = None,
 ) -> tuple[Callable | None, Any | None]:
@@ -83,7 +83,7 @@ def npugraph_ex_compile(
     config.mode = "reduce-overhead"
     # execute FX graph in eager mode before graph mode to optimize FX graph.
     config.debug.run_eagerly = True
-    if npugraph_ex_config.enable_static_kernel:
+    if ascend_compilation_config.enable_static_kernel:
         config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
         # According to the cudagraph_capture_size configuration, set the shapes
         # that can trigger the compilation of static kernel. If this configuration is
@@ -117,8 +117,8 @@ class AscendCompiler(CompilerInterface):
     name = "AscendCompiler"
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
-        npugraph_ex_config = get_ascend_config().npugraph_ex_config
-        if npugraph_ex_config.enable:
+        npugraph_ex_enabled = get_ascend_config().ascend_compilation_config.enable_npugraph_ex
+        if npugraph_ex_enabled:
             self.vllm_config = vllm_config
         return vllm_config.compute_hash()
 
@@ -134,11 +134,11 @@ class AscendCompiler(CompilerInterface):
         # see https://github.com/pytorch/pytorch/issues/138980
         graph = copy.deepcopy(graph)
 
-        npugraph_ex_config = get_ascend_config().npugraph_ex_config
-        if npugraph_ex_config.enable:
+        ascend_compilation_config = get_ascend_config().ascend_compilation_config
+        if ascend_compilation_config.enable_npugraph_ex:
             assert hasattr(self, "vllm_config")
             return npugraph_ex_compile(
-                graph, example_inputs, compiler_config, self.vllm_config, npugraph_ex_config, compile_range, key
+                graph, example_inputs, compiler_config, self.vllm_config, ascend_compilation_config, compile_range, key
             )
         else:
             return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)
diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py
index 43d23f37..40acb081 100644
--- a/vllm_ascend/compilation/graph_fusion_pass_manager.py
+++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py
@@ -64,6 +64,11 @@ class GraphFusionPassManager:
 
             self.passes.append(MatmulAllReduceAddRMSNormPass(config))
 
+        if self.ascend_compilation_config.get("fuse_muls_add", True):
+            from .passes.muls_add_pass import MulsAddFusionPass
+
+            self.passes.append(MulsAddFusionPass(config))
+
         if config.compilation_config.pass_config.enable_sp:
             from .passes.sequence_parallelism import AscendSequenceParallelismPass
 
diff --git a/vllm_ascend/compilation/passes/muls_add_pass.py b/vllm_ascend/compilation/passes/muls_add_pass.py
new file mode 100644
index 00000000..3d4ed764
--- /dev/null
+++ b/vllm_ascend/compilation/passes/muls_add_pass.py
@@ -0,0 +1,117 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import annotations
+
+import torch
+from torch._inductor.pattern_matcher import PatternMatcherPass
+from vllm.config import VllmConfig
+from vllm.config.compilation import Range
+from vllm.logger import logger
+
+from vllm_ascend.compilation.passes.base_pattern import BasePattern
+from vllm_ascend.utils import vllm_version_is
+
+if vllm_version_is("0.15.0"):
+    from vllm.compilation.vllm_inductor_pass import VllmInductorPass  # type: ignore
+else:
+    from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
+
+
+class MulsAddPattern(BasePattern):
+    """
+    Pattern that matches an element-wise mul + add sequence:
+        tmp = x * scale
+        out = tmp + y
+    and replaces it with a call to the muls_add_triton kernel.
+    """
+
+    def __init__(self, vllm_config: VllmConfig, scale: float = 1.0):
+        super().__init__(vllm_config)
+        self.scale = scale
+
+    def get_inputs(self) -> list[torch.Tensor]:
+        """
+        Generate example inputs for the MulsAddPattern.
+
+        The exact shapes are not important for pattern matching; they only
+        provide meta information for the pattern matcher.
+        """
+        x = torch.randn(2, 2048, device="npu", dtype=self.dtype)
+        y = torch.randn(2, 2048, device="npu", dtype=self.dtype)
+        # Only tensor inputs are needed here. The scalar scale is stored on the
+        # pattern instance (self.scale) instead of being passed as an input.
+        return [x, y]
+
+    def get_pattern(self):
+        def pattern(x: torch.Tensor, y: torch.Tensor):
+            """
+            Pattern for element-wise x * scale + y.
+            """
+            tmp = x * self.scale
+            out = tmp + y
+            return out
+
+        return pattern
+
+    def get_replacement(self):
+        def replacement(x: torch.Tensor, y: torch.Tensor):
+            """
+            Replacement that calls the muls_add_triton kernel using the
+            class-level scalar self.scale.
+            """
+            return torch.ops.vllm.muls_add(x, y, self.scale)
+
+        return replacement
+
+
+class MulsAddFusionPass(VllmInductorPass):
+    """
+    A fusion pass that replaces simple element-wise x * scale + y patterns
+    with the Triton-based muls_add_triton kernel on Ascend.
+    """
+
+    def __init__(self, vllm_config: VllmConfig):
+        super().__init__(vllm_config)
+        self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="muls_add_fusion_pass")
+
+        # For now we enable this pass for all floating-point dtypes that the
+        # model is configured to use.
+        dtype = vllm_config.model_config.dtype
+        if dtype not in (torch.float16, torch.bfloat16, torch.float32):
+            logger.debug("MulsAdd fusion not enabled: unsupported dtype %s", dtype)
+            return
+
+        # Currently we only register a single pattern instance with a fixed
+        # scalar scale value. If needed, multiple instances with different
+        # scales can be added here in the future.
+        MulsAddPattern(vllm_config, scale=1.0).register(self.pattern_match_passes)
+
+    def __call__(self, graph: torch.fx.Graph) -> None:  # type: ignore[override]
+        self.begin()
+        self.matched_count = self.pattern_match_passes.apply(graph)
+        logger.debug("Fused %s muls_add patterns", self.matched_count)
+        self.end_and_log()
+
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
+        """
+        Check if the pass is applicable for the current configuration.
+
+        For now, muls_add fusion is always allowed for the selected ranges.
+        This hook exists so that we can add more fine-grained range control
+        in the future if needed.
+        """
+        return True
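Conceptually, the new pass performs the rewrite sketched below on the traced FX graph (a simplified illustration, not the pattern matcher's actual internals; the multiplication must use the constant `scale` the pattern instance was registered with, currently `1.0`):

```python
import torch

# Before fusion: two element-wise ops (aten.mul followed by aten.add).
def before(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    tmp = x * 1.0  # scale matching the registered pattern instance
    return tmp + y

# After fusion: one call to the custom op backed by the Triton kernel.
def after(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.ops.vllm.muls_add(x, y, 1.0)
```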
+ """ + + def __init__(self, vllm_config: VllmConfig, scale: float = 1.0): + super().__init__(vllm_config) + self.scale = scale + + def get_inputs(self) -> list[torch.Tensor]: + """ + Generate example inputs for the MulsAddPattern. + + The exact shapes are not important for pattern matching; they only + provide meta information for the pattern matcher. + """ + x = torch.randn(2, 2048, device="npu", dtype=self.dtype) + y = torch.randn(2, 2048, device="npu", dtype=self.dtype) + # Only tensor inputs are needed here. The scalar scale is stored on the + # pattern instance (self.scale) instead of being passed as an input. + return [x, y] + + def get_pattern(self): + def pattern(x: torch.Tensor, y: torch.Tensor): + """ + Pattern for element-wise x * scale + y. + """ + tmp = x * self.scale + out = tmp + y + return out + + return pattern + + def get_replacement(self): + def replacement(x: torch.Tensor, y: torch.Tensor): + """ + Replacement that calls the muls_add_triton kernel using the + class-level scalar self.scale. + """ + return torch.ops.vllm.muls_add(x, y, self.scale) + + return replacement + + +class MulsAddFusionPass(VllmInductorPass): + """ + A fusion pass that replaces simple element-wise x * scale + y patterns + with the Triton-based muls_add_triton kernel on Ascend. + """ + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="muls_add_fusion_pass") + + # For now we enable this pass for all floating-point dtypes that the + # model is configured to use. + dtype = vllm_config.model_config.dtype + if dtype not in (torch.float16, torch.bfloat16, torch.float32): + logger.debug("MulsAdd fusion not enabled: unsupported dtype %s", dtype) + return + + # Currently we only register a single pattern instance with a fixed + # scalar scale value. If needed, multiple instances with different + # scales can be added here in the future. + MulsAddPattern(vllm_config, scale=1.0).register(self.pattern_match_passes) + + def __call__(self, graph: torch.fx.Graph) -> None: # type: ignore[override] + self.begin() + self.matched_count = self.pattern_match_passes.apply(graph) + logger.debug("Fused %s muls_add patterns", self.matched_count) + self.end_and_log() + + def is_applicable_for_range(self, compile_range: Range) -> bool: + """ + Check if the pass is applicable for the current configuration. + + For now, muls_add fusion is always allowed for the selected ranges. + This hook exists so that we can add more fine-grained range control + in the future if needed. 
+ """ + return True diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 348c936a..b803cd42 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -15,6 +15,7 @@ from vllm.utils.torch_utils import direct_register_custom_op from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.rotary_embedding import rope_forward_oot +from vllm_ascend.ops.triton.muls_add import muls_add_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import npu_stream_switch, prefetch_stream @@ -201,6 +202,14 @@ def _rope_forward_oot_impl_fake( return query, key +def _muls_add_impl_fake( + x: torch.Tensor, + y: torch.Tensor, + scale: float, +) -> torch.Tensor: + return torch.empty_like(x) + + direct_register_custom_op( op_name="maybe_chunk_residual", op_func=_maybe_chunk_residual_impl, @@ -272,3 +281,11 @@ direct_register_custom_op( mutates_args=[], dispatch_key="PrivateUse1", ) + +direct_register_custom_op( + op_name="muls_add", + op_func=muls_add_triton, + fake_impl=_muls_add_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/ops/triton/muls_add.py b/vllm_ascend/ops/triton/muls_add.py new file mode 100644 index 00000000..a7dcf0c7 --- /dev/null +++ b/vllm_ascend/ops/triton/muls_add.py @@ -0,0 +1,57 @@ +import torch +from vllm.triton_utils import tl, triton + +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num + + +@triton.jit +def muls_add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + scale, # Scale factor. + n_elements, # Size of the vector. + n_blocks, # Total number of blocks. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + for block_id in range(pid, n_blocks, num_programs): + block_start = block_id * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x * scale + y + tl.store(output_ptr + offsets, output, mask=mask) + + +def muls_add_triton(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor: + assert x.shape == y.shape, "Input tensors must have the same shape." + hidden_size = x.shape[-1] + + n_elements = x.numel() + output = torch.empty_like(x) + + # Determine the number of vector cores available + num_cores = get_vectorcore_num() + + # Define block size + BLOCK_SIZE = max(hidden_size // 2, 1024) + + # Calculate the number of programs to launch + num_blocks = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE + num_programs = min(num_blocks, num_cores) + + # Launch the Triton kernel + muls_add_kernel[(num_programs,)]( + x, + y, + output, + scale, + n_elements, + num_blocks, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return output diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index c0e52984..2d161a42 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -279,7 +279,7 @@ class NPUPlatform(Platform): if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: compilation_config.mode = CompilationMode.NONE - ascend_config.npugraph_ex_config.enable = False + ascend_config.ascend_compilation_config.enable_npugraph_ex = False elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.info("PIECEWISE compilation enabled on NPU. 
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index c0e52984..2d161a42 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -279,7 +279,7 @@ class NPUPlatform(Platform):
 
         if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             compilation_config.mode = CompilationMode.NONE
-            ascend_config.npugraph_ex_config.enable = False
+            ascend_config.ascend_compilation_config.enable_npugraph_ex = False
         elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
             logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode")
             assert compilation_config.mode == CompilationMode.VLLM_COMPILE, (
@@ -299,7 +299,7 @@ class NPUPlatform(Platform):
             # not be detected in advance assert.
             compilation_config.splitting_ops.extend(["vllm::mla_forward"])
             update_aclgraph_sizes(vllm_config)
-            ascend_config.npugraph_ex_config.enable = False
+            ascend_config.ascend_compilation_config.enable_npugraph_ex = False
         elif (
             compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
             or compilation_config.cudagraph_mode == CUDAGraphMode.FULL
@@ -328,7 +328,7 @@ class NPUPlatform(Platform):
             )
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
             compilation_config.mode = CompilationMode.NONE
-            ascend_config.npugraph_ex_config.enable = False
+            ascend_config.ascend_compilation_config.enable_npugraph_ex = False
 
         # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
         # Then, we will have to discuss the error handling strategy and user experience
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index 5e65521f..76bee45c 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -138,8 +138,8 @@ class NPUWorker(WorkerBase):
 
         self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER
 
-        npugraph_ex_config = get_ascend_config().npugraph_ex_config
-        if npugraph_ex_config.enable and npugraph_ex_config.enable_static_kernel:
+        ascend_compilation_config = get_ascend_config().ascend_compilation_config
+        if ascend_compilation_config.enable_npugraph_ex and ascend_compilation_config.enable_static_kernel:
             # Prevent duplicate triggers, execute the exit logic only once
             shutdown_request = False
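As the platform hooks above show, `enable_npugraph_ex` is force-disabled whenever `cudagraph_mode` resolves to `NONE` or `PIECEWISE`, so the backend only participates under full-graph capture. A hedged offline sketch (the model path is a placeholder):

```python
from vllm import LLM

# With FULL_DECODE_ONLY (or FULL) capture the flag is honored; with "NONE"
# or "PIECEWISE", NPUPlatform resets enable_npugraph_ex to False before use.
model = LLM(
    model="path/to/Qwen2-7B-Instruct",
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
    additional_config={"ascend_compilation_config": {"enable_npugraph_ex": True}},
)
```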