diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 80e6541e..a0f1edd5 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -115,12 +115,10 @@ def set_ascend_forward_context(
         # the performance may degrade due to the switching of communication methods.
         mmrs_fusion = True
         if is_moe_model(vllm_config):
-            sp_enabled = enable_sp(vllm_config) and \
-                tp_world_size > 1 and num_tokens is not None
+            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
             mmrs_fusion = False
         else:
             sp_enabled = enable_sp(vllm_config) and \
-                tp_world_size > 1 and \
                 num_tokens is not None and num_tokens > 1000
         forward_context.mmrs_fusion = mmrs_fusion
 
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index e1afd24a..46e80606 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -659,6 +659,17 @@ def enable_sp(vllm_config=None) -> bool:
             # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
             or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
 
+    if not _ENABLE_SP:
+        return _ENABLE_SP
+
+    assert vllm_config.parallel_config.tensor_parallel_size > 1, \
+        "Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
+
+    assert (
+        not is_moe_model(vllm_config)
+        or vllm_config.parallel_config.enable_expert_parallel
+    ), "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
+
     return _ENABLE_SP