[Feat] Support sequence parallelism by pass for VL models (#5632)

This commit is contained in:
realliujiaxu
2026-02-27 08:27:41 +08:00
committed by GitHub
parent ed175d6d92
commit 5def28dcd3
22 changed files with 460 additions and 101 deletions

View File

@@ -724,6 +724,19 @@ def matmul_allreduce_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
def enable_flash_comm_v1():
    """Report whether Flash Comm v1 has been switched on through the environment.

    Returns:
        The value of ``envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1`` when it is
        truthy; otherwise the legacy ``VLLM_ASCEND_ENABLE_FLASHCOMM`` variable
        parsed as an integer and converted to ``bool`` (unset counts as ``"0"``).

    Raises:
        ValueError: If the legacy variable has to be consulted and its value
            is not a valid integer literal.
    """
    primary_flag = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
    if primary_flag:
        return primary_flag
    # VLLM_ASCEND_ENABLE_FLASHCOMM1 is the supported switch; the older
    # VLLM_ASCEND_ENABLE_FLASHCOMM name is still honoured for backward compatibility.
    legacy_value = os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")
    return bool(int(legacy_value))
def enable_sp_by_pass(vllm_config: VllmConfig):
    """Tell whether sequence parallelism is requested through the pass config.

    Args:
        vllm_config: Configuration object; ``model_config.enforce_eager`` and
            ``compilation_config.pass_config.enable_sp`` are read from it.

    Returns:
        ``False`` when ``enforce_eager`` is truthy; otherwise the value of
        ``pass_config.enable_sp``.
    """
    # With eager execution enforced the pass-config flag is never consulted.
    if vllm_config.model_config.enforce_eager:
        return False
    return vllm_config.compilation_config.pass_config.enable_sp
def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
global _ENABLE_SP
if _ENABLE_SP is None:
@@ -731,29 +744,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
_ENABLE_SP = (
vllm_config.compilation_config.pass_config.enable_sp
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
)
_ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1()
if not _ENABLE_SP and enable_shared_expert_dp:
_ENABLE_SP = True
logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")
if not _ENABLE_SP:
return _ENABLE_SP
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
"Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
)
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
"Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
)
return _ENABLE_SP
@@ -1113,7 +1109,7 @@ def enable_dsa_cp() -> bool:
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
return bool(is_ds_v32 and enable_sp())
return bool(is_ds_v32 and enable_flash_comm_v1())
@lru_cache(maxsize=1)