[Triton][Config] Add muls_add triton kernel and refactor AscendCompilationConfig (#5518)
### What this PR does / why we need it?
Add muls_add triton kernel with related fusion pass. What's more, this
PR refactors `AscendCompilationConfig` and delete `NpugraphExConfig`.
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
CI passed with new added test.
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -118,8 +118,6 @@ class AscendConfig:
|
||||
from vllm_ascend.utils import get_flashcomm2_config_and_validate
|
||||
|
||||
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config)
|
||||
npugraph_ex_config = additional_config.get("npugraph_ex_config", {})
|
||||
self.npugraph_ex_config = NpugraphExConfig(**npugraph_ex_config)
|
||||
# We find that _npu_paged_attention still performs better than
|
||||
# npu_fused_infer_attention_score in some cases. We allow to execute
|
||||
# _npu_paged_attention in this cases. This should be removed once
|
||||
@@ -163,8 +161,8 @@ class AscendConfig:
|
||||
|
||||
def update_compile_ranges_split_points(self):
|
||||
vllm_config = self.vllm_config
|
||||
if self.npugraph_ex_config.enable:
|
||||
if self.npugraph_ex_config.fuse_allreduce_rms:
|
||||
if self.ascend_compilation_config.enable_npugraph_ex:
|
||||
if self.ascend_compilation_config.fuse_allreduce_rms:
|
||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
@@ -253,56 +251,9 @@ class AscendCompilationConfig:
|
||||
deployed on Ascend platforms.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = True, fuse_allreduce_rms: bool = False, **kwargs
|
||||
):
|
||||
"""
|
||||
Initialize the configuration.
|
||||
|
||||
Args:
|
||||
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
|
||||
When set to True, the system will optimize norm and quant operations.
|
||||
Default: True
|
||||
fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
|
||||
Default: True
|
||||
fuse_allreduce_rms (bool): Whether to enable allreduce and addrmsnorm fusion optimization.
|
||||
Default: False
|
||||
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
|
||||
"""
|
||||
self.fuse_norm_quant = fuse_norm_quant
|
||||
self.fuse_qknorm_rope = fuse_qknorm_rope
|
||||
self.fuse_allreduce_rms = fuse_allreduce_rms
|
||||
|
||||
|
||||
class AscendFusionConfig:
|
||||
"""
|
||||
Configuration for controlling whether to use a fused operator gmmswigluquant.
|
||||
"""
|
||||
|
||||
def __init__(self, fusion_ops_gmmswigluquant: bool = True, **kwargs):
|
||||
"""
|
||||
Initialize the configuration.
|
||||
|
||||
Args:
|
||||
fusion_ops_gmmswigluquant (bool): Whether to use a fused operator gmmswigluquant.
|
||||
When set to True, the system will use a fused operator gmmswigluquant.
|
||||
Default: True
|
||||
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
|
||||
"""
|
||||
self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant
|
||||
|
||||
|
||||
class NpugraphExConfig:
|
||||
"""
|
||||
Configuration for controlling the behavior of npugraph_ex backend.
|
||||
|
||||
This class provides a way to configure whether to use the npugraph_ex backend and static kernel.
|
||||
These configurations can directly impact the performance and behavior of models deployed on Ascend platforms.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enable: bool = True,
|
||||
enable_npugraph_ex: bool = True,
|
||||
enable_static_kernel: bool = False,
|
||||
fuse_norm_quant: bool = True,
|
||||
fuse_qknorm_rope: bool = True,
|
||||
@@ -313,7 +264,7 @@ class NpugraphExConfig:
|
||||
Initialize the configuration.
|
||||
|
||||
Args:
|
||||
enable (bool): Whether to enable npugraph_ex backend.
|
||||
enable_npugraph_ex (bool): Whether to enable npugraph_ex backend.
|
||||
When set to True, the Fx graph generated by Dymano will be
|
||||
optimized and compiled by the npugraph_ex backend.
|
||||
Default: True
|
||||
@@ -333,11 +284,32 @@ class NpugraphExConfig:
|
||||
Default: False
|
||||
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
|
||||
"""
|
||||
self.enable = enable
|
||||
self.enable_static_kernel = enable_static_kernel
|
||||
self.fuse_norm_quant = fuse_norm_quant
|
||||
self.fuse_qknorm_rope = fuse_qknorm_rope
|
||||
self.fuse_allreduce_rms = fuse_allreduce_rms
|
||||
self.enable_npugraph_ex = enable_npugraph_ex
|
||||
self.enable_static_kernel = enable_static_kernel
|
||||
self.fuse_muls_add = kwargs.get("fuse_muls_add", True)
|
||||
if self.enable_static_kernel:
|
||||
assert self.enable_npugraph_ex, "Static kernel generation requires npugraph_ex to be enabled."
|
||||
|
||||
|
||||
class AscendFusionConfig:
|
||||
"""
|
||||
Configuration for controlling whether to use a fused operator gmmswigluquant.
|
||||
"""
|
||||
|
||||
def __init__(self, fusion_ops_gmmswigluquant: bool = True, **kwargs):
|
||||
"""
|
||||
Initialize the configuration.
|
||||
|
||||
Args:
|
||||
fusion_ops_gmmswigluquant (bool): Whether to use a fused operator gmmswigluquant.
|
||||
When set to True, the system will use a fused operator gmmswigluquant.
|
||||
Default: True
|
||||
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
|
||||
"""
|
||||
self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant
|
||||
|
||||
|
||||
class XliteGraphConfig:
|
||||
|
||||
@@ -30,7 +30,7 @@ from vllm.compilation.compiler_interface import CompilerInterface
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
|
||||
from vllm_ascend.ascend_config import NpugraphExConfig, get_ascend_config
|
||||
from vllm_ascend.ascend_config import AscendCompilationConfig, get_ascend_config
|
||||
from vllm_ascend.utils import COMPILATION_PASS_KEY
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ def npugraph_ex_compile(
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
vllm_config: VllmConfig,
|
||||
npugraph_ex_config: NpugraphExConfig,
|
||||
ascend_compilation_config: AscendCompilationConfig,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
@@ -83,7 +83,7 @@ def npugraph_ex_compile(
|
||||
config.mode = "reduce-overhead"
|
||||
# execute FX graph in eager mode before graph mode to optimize FX graph.
|
||||
config.debug.run_eagerly = True
|
||||
if npugraph_ex_config.enable_static_kernel:
|
||||
if ascend_compilation_config.enable_static_kernel:
|
||||
config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
|
||||
# According to the cudagraph_capture_size configuration, set the shapes
|
||||
# that can trigger the compilation of static kernel. If this configuration is
|
||||
@@ -117,8 +117,8 @@ class AscendCompiler(CompilerInterface):
|
||||
name = "AscendCompiler"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
npugraph_ex_config = get_ascend_config().npugraph_ex_config
|
||||
if npugraph_ex_config.enable:
|
||||
npugraph_ex_enabled = get_ascend_config().ascend_compilation_config.enable_npugraph_ex
|
||||
if npugraph_ex_enabled:
|
||||
self.vllm_config = vllm_config
|
||||
return vllm_config.compute_hash()
|
||||
|
||||
@@ -134,11 +134,11 @@ class AscendCompiler(CompilerInterface):
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
npugraph_ex_config = get_ascend_config().npugraph_ex_config
|
||||
if npugraph_ex_config.enable:
|
||||
ascend_compilation_config = get_ascend_config().ascend_compilation_config
|
||||
if ascend_compilation_config.enable_npugraph_ex:
|
||||
assert hasattr(self, "vllm_config")
|
||||
return npugraph_ex_compile(
|
||||
graph, example_inputs, compiler_config, self.vllm_config, npugraph_ex_config, compile_range, key
|
||||
graph, example_inputs, compiler_config, self.vllm_config, ascend_compilation_config, compile_range, key
|
||||
)
|
||||
else:
|
||||
return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)
|
||||
|
||||
@@ -64,6 +64,11 @@ class GraphFusionPassManager:
|
||||
|
||||
self.passes.append(MatmulAllReduceAddRMSNormPass(config))
|
||||
|
||||
if self.ascend_compilation_config.get("fuse_muls_add", True):
|
||||
from .passes.muls_add_pass import MulsAddFusionPass
|
||||
|
||||
self.passes.append(MulsAddFusionPass(config))
|
||||
|
||||
if config.compilation_config.pass_config.enable_sp:
|
||||
from .passes.sequence_parallelism import AscendSequenceParallelismPass
|
||||
|
||||
|
||||
117
vllm_ascend/compilation/passes/muls_add_pass.py
Normal file
117
vllm_ascend/compilation/passes/muls_add_pass.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.compilation.passes.base_pattern import BasePattern
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.15.0"):
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass # type: ignore
|
||||
else:
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class MulsAddPattern(BasePattern):
|
||||
"""
|
||||
Pattern that matches an element-wise mul + add sequence:
|
||||
tmp = x * scale
|
||||
out = tmp + y
|
||||
and replaces it with a call to the muls_add_triton kernel.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, scale: float = 1.0):
|
||||
super().__init__(vllm_config)
|
||||
self.scale = scale
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
"""
|
||||
Generate example inputs for the MulsAddPattern.
|
||||
|
||||
The exact shapes are not important for pattern matching; they only
|
||||
provide meta information for the pattern matcher.
|
||||
"""
|
||||
x = torch.randn(2, 2048, device="npu", dtype=self.dtype)
|
||||
y = torch.randn(2, 2048, device="npu", dtype=self.dtype)
|
||||
# Only tensor inputs are needed here. The scalar scale is stored on the
|
||||
# pattern instance (self.scale) instead of being passed as an input.
|
||||
return [x, y]
|
||||
|
||||
def get_pattern(self):
|
||||
def pattern(x: torch.Tensor, y: torch.Tensor):
|
||||
"""
|
||||
Pattern for element-wise x * scale + y.
|
||||
"""
|
||||
tmp = x * self.scale
|
||||
out = tmp + y
|
||||
return out
|
||||
|
||||
return pattern
|
||||
|
||||
def get_replacement(self):
|
||||
def replacement(x: torch.Tensor, y: torch.Tensor):
|
||||
"""
|
||||
Replacement that calls the muls_add_triton kernel using the
|
||||
class-level scalar self.scale.
|
||||
"""
|
||||
return torch.ops.vllm.muls_add(x, y, self.scale)
|
||||
|
||||
return replacement
|
||||
|
||||
|
||||
class MulsAddFusionPass(VllmInductorPass):
|
||||
"""
|
||||
A fusion pass that replaces simple element-wise x * scale + y patterns
|
||||
with the Triton-based muls_add_triton kernel on Ascend.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="muls_add_fusion_pass")
|
||||
|
||||
# For now we enable this pass for all floating-point dtypes that the
|
||||
# model is configured to use.
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in (torch.float16, torch.bfloat16, torch.float32):
|
||||
logger.debug("MulsAdd fusion not enabled: unsupported dtype %s", dtype)
|
||||
return
|
||||
|
||||
# Currently we only register a single pattern instance with a fixed
|
||||
# scalar scale value. If needed, multiple instances with different
|
||||
# scales can be added here in the future.
|
||||
MulsAddPattern(vllm_config, scale=1.0).register(self.pattern_match_passes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph) -> None: # type: ignore[override]
|
||||
self.begin()
|
||||
self.matched_count = self.pattern_match_passes.apply(graph)
|
||||
logger.debug("Fused %s muls_add patterns", self.matched_count)
|
||||
self.end_and_log()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
"""
|
||||
Check if the pass is applicable for the current configuration.
|
||||
|
||||
For now, muls_add fusion is always allowed for the selected ranges.
|
||||
This hook exists so that we can add more fine-grained range control
|
||||
in the future if needed.
|
||||
"""
|
||||
return True
|
||||
@@ -15,6 +15,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
|
||||
from vllm_ascend.ops.triton.muls_add import muls_add_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
|
||||
@@ -201,6 +202,14 @@ def _rope_forward_oot_impl_fake(
|
||||
return query, key
|
||||
|
||||
|
||||
def _muls_add_impl_fake(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
scale: float,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
@@ -272,3 +281,11 @@ direct_register_custom_op(
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="muls_add",
|
||||
op_func=muls_add_triton,
|
||||
fake_impl=_muls_add_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
|
||||
57
vllm_ascend/ops/triton/muls_add.py
Normal file
57
vllm_ascend/ops/triton/muls_add.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
|
||||
@triton.jit
|
||||
def muls_add_kernel(
|
||||
x_ptr, # *Pointer* to first input vector.
|
||||
y_ptr, # *Pointer* to second input vector.
|
||||
output_ptr, # *Pointer* to output vector.
|
||||
scale, # Scale factor.
|
||||
n_elements, # Size of the vector.
|
||||
n_blocks, # Total number of blocks.
|
||||
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_programs = tl.num_programs(axis=0)
|
||||
for block_id in range(pid, n_blocks, num_programs):
|
||||
block_start = block_id * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x * scale + y
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
def muls_add_triton(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
|
||||
assert x.shape == y.shape, "Input tensors must have the same shape."
|
||||
hidden_size = x.shape[-1]
|
||||
|
||||
n_elements = x.numel()
|
||||
output = torch.empty_like(x)
|
||||
|
||||
# Determine the number of vector cores available
|
||||
num_cores = get_vectorcore_num()
|
||||
|
||||
# Define block size
|
||||
BLOCK_SIZE = max(hidden_size // 2, 1024)
|
||||
|
||||
# Calculate the number of programs to launch
|
||||
num_blocks = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
|
||||
num_programs = min(num_blocks, num_cores)
|
||||
|
||||
# Launch the Triton kernel
|
||||
muls_add_kernel[(num_programs,)](
|
||||
x,
|
||||
y,
|
||||
output,
|
||||
scale,
|
||||
n_elements,
|
||||
num_blocks,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -279,7 +279,7 @@ class NPUPlatform(Platform):
|
||||
|
||||
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
compilation_config.mode = CompilationMode.NONE
|
||||
ascend_config.npugraph_ex_config.enable = False
|
||||
ascend_config.ascend_compilation_config.enable_npugraph_ex = False
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode")
|
||||
assert compilation_config.mode == CompilationMode.VLLM_COMPILE, (
|
||||
@@ -299,7 +299,7 @@ class NPUPlatform(Platform):
|
||||
# not be detected in advance assert.
|
||||
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
|
||||
update_aclgraph_sizes(vllm_config)
|
||||
ascend_config.npugraph_ex_config.enable = False
|
||||
ascend_config.ascend_compilation_config.enable_npugraph_ex = False
|
||||
elif (
|
||||
compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
|
||||
or compilation_config.cudagraph_mode == CUDAGraphMode.FULL
|
||||
@@ -328,7 +328,7 @@ class NPUPlatform(Platform):
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
compilation_config.mode = CompilationMode.NONE
|
||||
ascend_config.npugraph_ex_config.enable = False
|
||||
ascend_config.ascend_compilation_config.enable_npugraph_ex = False
|
||||
|
||||
# TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
|
||||
# Then, we will have to discuss the error handling strategy and user experience
|
||||
|
||||
@@ -138,8 +138,8 @@ class NPUWorker(WorkerBase):
|
||||
|
||||
self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER
|
||||
|
||||
npugraph_ex_config = get_ascend_config().npugraph_ex_config
|
||||
if npugraph_ex_config.enable and npugraph_ex_config.enable_static_kernel:
|
||||
ascend_compilation_config = get_ascend_config().ascend_compilation_config
|
||||
if ascend_compilation_config.enable_npugraph_ex and ascend_compilation_config.enable_static_kernel:
|
||||
# Prevent duplicate triggers, execute the exit logic only once
|
||||
shutdown_request = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user