[Feat]support sequence parallelism by pass for VL models (#5632)
This commit is contained in:
@@ -30,6 +30,7 @@ class AscendConfig:
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.vllm_config = vllm_config
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
|
||||
xlite_graph_config = additional_config.get("xlite_graph_config", {})
|
||||
@@ -160,6 +161,47 @@ class AscendConfig:
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def update_compile_ranges_split_points(self):
|
||||
vllm_config = self.vllm_config
|
||||
if self.npugraph_ex_config.enable:
|
||||
if self.npugraph_ex_config.fuse_allreduce_rms:
|
||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||
logger.debug(
|
||||
"set compile_ranges_split_points to "
|
||||
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||
)
|
||||
|
||||
else:
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
|
||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||
logger.debug(
|
||||
"set compile_ranges_split_points to "
|
||||
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||
)
|
||||
|
||||
from vllm_ascend.utils import is_moe_model
|
||||
|
||||
if vllm_config.compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config):
|
||||
from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold
|
||||
|
||||
sp_threshold = get_sp_threshold(vllm_config)
|
||||
new_compile_ranges_split_points.append(sp_threshold)
|
||||
logger.debug(f"add {sp_threshold} to compile_ranges_split_points for sequence parallelism")
|
||||
if len(new_compile_ranges_split_points) > len(vllm_config.compilation_config.compile_ranges_split_points):
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||
|
||||
|
||||
class FinegrainedTPConfig:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user