[Feat]support sequence parallelism by pass for VL models (#5632)
This commit is contained in:
@@ -724,6 +724,19 @@ def matmul_allreduce_enable() -> bool:
|
||||
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||
|
||||
|
||||
def enable_flash_comm_v1():
|
||||
return (
|
||||
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
||||
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
|
||||
)
|
||||
|
||||
|
||||
def enable_sp_by_pass(vllm_config: VllmConfig):
|
||||
return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
|
||||
|
||||
|
||||
def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
|
||||
global _ENABLE_SP
|
||||
if _ENABLE_SP is None:
|
||||
@@ -731,29 +744,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
_ENABLE_SP = (
|
||||
vllm_config.compilation_config.pass_config.enable_sp
|
||||
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
||||
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
|
||||
)
|
||||
_ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1()
|
||||
|
||||
if not _ENABLE_SP and enable_shared_expert_dp:
|
||||
_ENABLE_SP = True
|
||||
logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")
|
||||
|
||||
if not _ENABLE_SP:
|
||||
return _ENABLE_SP
|
||||
|
||||
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
|
||||
"Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
|
||||
)
|
||||
|
||||
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
|
||||
"Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
|
||||
)
|
||||
|
||||
return _ENABLE_SP
|
||||
|
||||
|
||||
@@ -1113,7 +1109,7 @@ def enable_dsa_cp() -> bool:
|
||||
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
|
||||
vllm_config.model_config.hf_text_config, "index_topk"
|
||||
)
|
||||
return bool(is_ds_v32 and enable_sp())
|
||||
return bool(is_ds_v32 and enable_flash_comm_v1())
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
|
||||
Reference in New Issue
Block a user