[bugfix][npugraph_ex]add the extra check for allreduce rmsnorm fusion pass (#6430)
### What this PR does / why we need it?
Allreduce rmsnorm fusion pass has an additional check condition, which
requires fusion of the Fx graph only when the start of compile_range is
greater than 512. We previously overlooked this check.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
This commit is contained in:
@@ -16,6 +16,8 @@
|
|||||||
#
|
#
|
||||||
import torch
|
import torch
|
||||||
import torchair
|
import torchair
|
||||||
|
from torch._inductor.pattern_matcher import Match
|
||||||
|
from vllm.compilation.inductor_pass import get_pass_context
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.compilation import Range
|
from vllm.config.compilation import Range
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
|
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
|
||||||
@@ -27,6 +29,11 @@ from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check im
|
|||||||
ALLREDUCE_NORM_FUSE_THREHOLD = 512
|
ALLREDUCE_NORM_FUSE_THREHOLD = 512
|
||||||
|
|
||||||
|
|
||||||
|
def extra_check_for_allreduce_rmsnorm_fusion_pass(match: Match) -> bool:
|
||||||
|
compile_range = get_pass_context().compile_range
|
||||||
|
return extra_stream_scope_check(match) and compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD
|
||||||
|
|
||||||
|
|
||||||
class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
|
class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
|
||||||
"""
|
"""
|
||||||
recognizing the Matmul + AllReduce + AddRMSNorm computation pattern
|
recognizing the Matmul + AllReduce + AddRMSNorm computation pattern
|
||||||
@@ -80,7 +87,7 @@ class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
|
|||||||
search_fn=pattern,
|
search_fn=pattern,
|
||||||
replace_fn=replacement,
|
replace_fn=replacement,
|
||||||
example_inputs=self.get_inputs(),
|
example_inputs=self.get_inputs(),
|
||||||
extra_check=extra_stream_scope_check,
|
extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -130,7 +137,7 @@ class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern:
|
|||||||
search_fn=pattern,
|
search_fn=pattern,
|
||||||
replace_fn=replacement,
|
replace_fn=replacement,
|
||||||
example_inputs=self.get_inputs(),
|
example_inputs=self.get_inputs(),
|
||||||
extra_check=extra_stream_scope_check,
|
extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user