[Fusion] [Graph]Add Matmul Allreduce Rmsnorm fusion Pass (#5034)
This PR add `MatmulAllreduceRmsnorm` operator and introduces a graph
fusion pass for `matmul_allreduce_rmsnorm` operations. The
implementation includes a new configuration flag, a pattern matching
pass using `torch._inductor.pattern_matcher`.
Co-authored-by: Trunrain [270250579@qq.com](mailto:270250579@qq.com)
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: tongrunze <t00574058@china.huawei.com>
This commit is contained in:
@@ -192,6 +192,18 @@ class NPUPlatform(Platform):
|
||||
else ascend_compilation_config
|
||||
)
|
||||
|
||||
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
|
||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
|
||||
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||
logger.debug(
|
||||
"set compile_ranges_split_points to "
|
||||
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||
)
|
||||
|
||||
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
|
||||
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
|
||||
if model_config is None:
|
||||
|
||||
Reference in New Issue
Block a user