diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 8888b70..607f029 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -106,7 +106,7 @@ def set_ascend_forward_context( # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, # the performance may degrade due to the switching of communication methods. - sp_enabled = enable_sp() and \ + sp_enabled = enable_sp(vllm_config) and \ tp_world_size > 1 and \ num_tokens is not None and num_tokens > 1000 diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 6107e25..805fd57 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -597,11 +597,12 @@ def dense_optim_enable() -> bool: return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE -def enable_sp() -> bool: - from vllm.config import get_cached_compilation_config - +def enable_sp(vllm_config=None) -> bool: + if vllm_config is None: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() return ( - get_cached_compilation_config().pass_config.enable_sequence_parallelism + vllm_config.compilation_config.pass_config.enable_sequence_parallelism or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)