[bugfix][npugraph_ex]add the extra check for allreduce rmsnorm fusion pass (#6430)
### What this PR does / why we need it?
Allreduce rmsnorm fusion pass has an additional check condition, which
requires fusion of the Fx graph only when the start of compile_range is
greater than 512. We previously overlooked this check.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
This commit is contained in:
@@ -16,6 +16,8 @@
|
||||
#
|
||||
import torch
|
||||
import torchair
|
||||
from torch._inductor.pattern_matcher import Match
|
||||
from vllm.compilation.inductor_pass import get_pass_context
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
|
||||
@@ -27,6 +29,11 @@ from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check im
|
||||
ALLREDUCE_NORM_FUSE_THREHOLD = 512
|
||||
|
||||
|
||||
def extra_check_for_allreduce_rmsnorm_fusion_pass(match: Match) -> bool:
|
||||
compile_range = get_pass_context().compile_range
|
||||
return extra_stream_scope_check(match) and compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD
|
||||
|
||||
|
||||
class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
|
||||
"""
|
||||
recognizing the Matmul + AllReduce + AddRMSNorm computation pattern
|
||||
@@ -80,7 +87,7 @@ class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
|
||||
search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=self.get_inputs(),
|
||||
extra_check=extra_stream_scope_check,
|
||||
extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass,
|
||||
)
|
||||
|
||||
|
||||
@@ -130,7 +137,7 @@ class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern:
|
||||
search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=self.get_inputs(),
|
||||
extra_check=extra_stream_scope_check,
|
||||
extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user