[Triton][Config] Add muls_add triton kernel and refactor AscendCompilationConfig (#5518)
### What this PR does / why we need it?
Add a `muls_add` Triton kernel with a matching fusion pass. This PR also refactors `AscendCompilationConfig` and deletes `NpugraphExConfig`.
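For orientation, here is a minimal sketch of what a `muls_add` kernel can look like in Triton, assuming it fuses a scalar multiply with an elementwise add (`out = x * scale + y`). This is an illustration, not the kernel added by this PR:

```python
# Illustrative sketch only -- not the kernel added in this PR.
# Assumes muls_add means: out = x * scale + y, fused in one memory pass.
import torch
import triton
import triton.language as tl


@triton.jit
def muls_add_kernel(x_ptr, y_ptr, out_ptr, scale, n_elements,
                    BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # The fusion: scalar multiply and add with no intermediate tensor.
    tl.store(out_ptr + offsets, x * scale + y, mask=mask)


def muls_add(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    muls_add_kernel[grid](x, y, out, scale, n, BLOCK_SIZE=1024)
    return out
```

The accompanying fusion pass would then rewrite matching multiply/add node pairs in the compiled graph to call such a kernel, saving a kernel launch and one round trip through memory.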
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
CI passed with the newly added test.
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
```diff
@@ -118,8 +118,6 @@ class AscendConfig:
         from vllm_ascend.utils import get_flashcomm2_config_and_validate
 
         self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config)
-        npugraph_ex_config = additional_config.get("npugraph_ex_config", {})
-        self.npugraph_ex_config = NpugraphExConfig(**npugraph_ex_config)
         # We find that _npu_paged_attention still performs better than
         # npu_fused_infer_attention_score in some cases. We allow to execute
         # _npu_paged_attention in this cases. This should be removed once
@@ -163,8 +161,8 @@ class AscendConfig:
 
     def update_compile_ranges_split_points(self):
         vllm_config = self.vllm_config
-        if self.npugraph_ex_config.enable:
-            if self.npugraph_ex_config.fuse_allreduce_rms:
+        if self.ascend_compilation_config.enable_npugraph_ex:
+            if self.ascend_compilation_config.fuse_allreduce_rms:
                 from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
 
                 new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
@@ -253,56 +251,9 @@ class AscendCompilationConfig:
     deployed on Ascend platforms.
     """
 
-    def __init__(
-        self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = True, fuse_allreduce_rms: bool = False, **kwargs
-    ):
-        """
-        Initialize the configuration.
-
-        Args:
-            fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
-                When set to True, the system will optimize norm and quant operations.
-                Default: True
-            fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
-                Default: True
-            fuse_allreduce_rms (bool): Whether to enable allreduce and addrmsnorm fusion optimization.
-                Default: False
-            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
-        """
-        self.fuse_norm_quant = fuse_norm_quant
-        self.fuse_qknorm_rope = fuse_qknorm_rope
-        self.fuse_allreduce_rms = fuse_allreduce_rms
-
-
-class AscendFusionConfig:
-    """
-    Configuration for controlling whether to use a fused operator gmmswigluquant.
-    """
-
-    def __init__(self, fusion_ops_gmmswigluquant: bool = True, **kwargs):
-        """
-        Initialize the configuration.
-
-        Args:
-            fusion_ops_gmmswigluquant (bool): Whether to use a fused operator gmmswigluquant.
-                When set to True, the system will use a fused operator gmmswigluquant.
-                Default: True
-            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
-        """
-        self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant
-
-
-class NpugraphExConfig:
-    """
-    Configuration for controlling the behavior of npugraph_ex backend.
-
-    This class provides a way to configure whether to use the npugraph_ex backend and static kernel.
-    These configurations can directly impact the performance and behavior of models deployed on Ascend platforms.
-    """
-
     def __init__(
         self,
-        enable: bool = True,
+        enable_npugraph_ex: bool = True,
         enable_static_kernel: bool = False,
         fuse_norm_quant: bool = True,
         fuse_qknorm_rope: bool = True,
@@ -313,7 +264,7 @@ class NpugraphExConfig:
         Initialize the configuration.
 
         Args:
-            enable (bool): Whether to enable npugraph_ex backend.
+            enable_npugraph_ex (bool): Whether to enable npugraph_ex backend.
                 When set to True, the Fx graph generated by Dymano will be
                 optimized and compiled by the npugraph_ex backend.
                 Default: True
@@ -333,11 +284,32 @@ class NpugraphExConfig:
                 Default: False
             **kwargs: Additional optional parameters for forward compatibility and configuration extension.
         """
-        self.enable = enable
-        self.enable_static_kernel = enable_static_kernel
         self.fuse_norm_quant = fuse_norm_quant
         self.fuse_qknorm_rope = fuse_qknorm_rope
         self.fuse_allreduce_rms = fuse_allreduce_rms
+        self.enable_npugraph_ex = enable_npugraph_ex
+        self.enable_static_kernel = enable_static_kernel
+        self.fuse_muls_add = kwargs.get("fuse_muls_add", True)
+        if self.enable_static_kernel:
+            assert self.enable_npugraph_ex, "Static kernel generation requires npugraph_ex to be enabled."
+
+
+class AscendFusionConfig:
+    """
+    Configuration for controlling whether to use a fused operator gmmswigluquant.
+    """
+
+    def __init__(self, fusion_ops_gmmswigluquant: bool = True, **kwargs):
+        """
+        Initialize the configuration.
+
+        Args:
+            fusion_ops_gmmswigluquant (bool): Whether to use a fused operator gmmswigluquant.
+                When set to True, the system will use a fused operator gmmswigluquant.
+                Default: True
+            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
+        """
+        self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant
 
 
 class XliteGraphConfig:
```
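In practice, the refactor means the former `NpugraphExConfig` options are now read from `AscendCompilationConfig`. Below is a hedged sketch of how they might be passed through vLLM's `additional_config`; the `ascend_compilation_config` key is an assumption inferred from the attribute name in the diff above, and the model name is a placeholder:

```python
# Hedged sketch: the "ascend_compilation_config" key is inferred from the
# self.ascend_compilation_config attribute in the diff; consult the
# vllm-ascend documentation for the exact additional_config schema.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    additional_config={
        "ascend_compilation_config": {
            "enable_npugraph_ex": True,
            "fuse_allreduce_rms": False,
            # New in this PR; read via kwargs.get("fuse_muls_add", True),
            # so it defaults to True when omitted.
            "fuse_muls_add": True,
        },
    },
)
```

Note that `enable_static_kernel=True` now asserts `enable_npugraph_ex=True`, so static kernel generation cannot be enabled on its own.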