[Feat] Support sequence parallelism pass for VL models (#5632)

Author: realliujiaxu
Date: 2026-02-27 08:27:41 +08:00
Committed by: GitHub
Parent: ed175d6d92
Commit: 5def28dcd3
22 changed files with 460 additions and 101 deletions

@@ -48,6 +48,7 @@ from vllm_ascend.utils import (
     update_aclgraph_sizes,
     update_cudagraph_capture_sizes,
     is_310p,
+    enable_flash_comm_v1,
 )

 if TYPE_CHECKING:
@@ -198,32 +199,9 @@ class NPUPlatform(Platform):
             if not isinstance(ascend_compilation_config, dict)
             else ascend_compilation_config
         )
-        ascend_config.update_compile_ranges_split_points()
-        if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
-            from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
-            new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
-            new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
-            new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
-            vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
-            logger.debug(
-                "set compile_ranges_split_points to "
-                "{new_compile_ranges_split_points} for matmul and allreduce fusion"
-            )
-        npugraph_ex_config = ascend_config.npugraph_ex_config
-        if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
-            from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
-            new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
-            new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
-            new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
-            vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
-            logger.debug(
-                "set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
-            )
-        elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
+        if model_config and hasattr(model_config.hf_text_config, "index_topk"):
             vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
         ascend_fusion_config = ascend_config.ascend_fusion_config
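
Note: both removed branches performed the same split-point merge. A minimal standalone sketch of that step, with an illustrative stand-in value rather than the real ALLREDUCE_NORM_FUSE_THRESHOLD constant:

# Sketch of the removed split-point merge. In the real code the threshold is
# imported from the allreduce_rmsnorm fusion pass; 1000 here is a stand-in.
ALLREDUCE_NORM_FUSE_THRESHOLD = 1000  # illustrative value only

compile_ranges_split_points = [512, 2048]
compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
compile_ranges_split_points = sorted(compile_ranges_split_points)
# -> [512, 1000, 2048]: the compile ranges now split at the fusion threshold,
# so graphs above and below it can apply matmul+allreduce fusion differently.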
@@ -417,15 +395,19 @@ class NPUPlatform(Platform):
             )
             vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
-        if is_vl_model(vllm_config):
-            if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
-                int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
-            ):
-                raise ValueError(
-                    "Currently, VL models doesn't support "
-                    "FLASHCOMM in vllm-ascend. We will fix this in the future. "
-                    "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
-                )
+        if enable_flash_comm_v1():
+            assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
+Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
+For optimal performance with VL models, we recommend enabling Sequence Parallelism \
+via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
+            assert vllm_config.parallel_config.tensor_parallel_size > 1, (
+                "Flash Comm v1 is only supported when tp_size > 1."
+            )
+            assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
+                "Flash Comm v1 requires enable_expert_parallel=True for MoE models."
+            )
         # Set "PYTORCH_NPU_ALLOC_CONF=expandable_segments:True" by default to optimize NPU memory management.
         # Find more details at https://docs.vllm.ai/projects/ascend/en/latest/faqs.html#how-to-handle-the-out-of-memory-issue
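
Note: the assert message above points VL users at the sequence-parallelism pass instead of Flash Comm V1. A minimal offline-inference sketch of that recommendation, assuming vLLM's LLM constructor accepts a compilation_config dict; the model name is a hypothetical example, and the CLI equivalent is the --compilation-config flag quoted in the assert:

# Sketch only: enable the SP pass for a VL model, per the assert message above.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # hypothetical example VL model
    tensor_parallel_size=2,               # SP only applies when tp_size > 1
    compilation_config={"pass_config": {"enable_sp": True}},
)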
@@ -626,16 +608,16 @@ class NPUPlatform(Platform):
         # communication methods.
         mmrs_fusion = True
         if is_moe_model(vllm_config):
-            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
+            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
             mmrs_fusion = False
         else:
-            sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
+            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
         # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
         flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
         pad_size = None
         padded_length = None
-        if sp_enabled or flashcomm_v2_enabled:
+        if flash_comm_v1_enabled or flashcomm_v2_enabled:
             pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
         if num_tokens is None and attn_metadata is not None:
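
Note: the padding step above rounds num_tokens up to a multiple of tp_world_size so the token dimension can be split evenly across TP ranks. A standalone sketch with a worked example:

# Same formula as in the diff above: pad token count to a multiple of TP size.
def compute_pad_size(num_tokens: int, tp_world_size: int) -> int:
    return (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

assert compute_pad_size(1003, 4) == 1  # 1003 + 1 = 1004 = 4 * 251 tokens per rank
assert compute_pad_size(1004, 4) == 0  # already evenly divisible, no padding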
@@ -643,7 +625,7 @@
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and dp_metadata is not None:
             max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
-            if sp_enabled or flashcomm_v2_enabled:
+            if flash_comm_v1_enabled or flashcomm_v2_enabled:
                 padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                 pad_size = padded_length - num_tokens
             else:
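
Note: under data parallelism the branch above pads every DP rank to the same target length, taken from the maximum token count across DP ranks and rounded up to a multiple of tp_world_size. A standalone sketch:

# Mirrors the DP branch above: all DP ranks share one padded length.
def compute_padded_length(max_tokens_across_dp: int, tp_world_size: int) -> int:
    return (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size

padded_length = compute_padded_length(1003, 4)  # -> 1004
pad_size = padded_length - 900  # a rank holding 900 tokens pads by 104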
@@ -664,7 +646,7 @@
             "capturing": capturing,
             "mmrs_fusion": mmrs_fusion,
             "num_tokens": num_tokens,
-            "sp_enabled": sp_enabled,
+            "flash_comm_v1_enabled": flash_comm_v1_enabled,
             "flashcomm_v2_enabled": flashcomm_v2_enabled,
             "pad_size": pad_size,
             "padded_length": padded_length,