diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 47cc2e9..85348db 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -114,7 +114,7 @@ def set_ascend_forward_context(
     # the performance may degrade due to the switching of communication methods.
     if is_moe_model(vllm_config):
         sp_enabled = enable_sp(vllm_config) and \
-            tp_world_size > 1
+            tp_world_size > 1 and num_tokens is not None
     else:
         sp_enabled = enable_sp(vllm_config) and \
             tp_world_size > 1 and \
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 7bc79d8..0929e40 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -56,6 +56,7 @@ _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
 _IS_MOE_MODEL = None
+_ENABLE_SP = None
 
 
 def is_310p():
@@ -606,15 +607,20 @@ def dense_optim_enable() -> bool:
 
 
 def enable_sp(vllm_config=None) -> bool:
-    if vllm_config is None:
-        from vllm.config import get_current_vllm_config
-        vllm_config = get_current_vllm_config()
-    return (
-        vllm_config.compilation_config.pass_config.enable_sequence_parallelism
-        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
-        # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
-        # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
-        or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
+    global _ENABLE_SP
+    if _ENABLE_SP is None:
+        if vllm_config is None:
+            from vllm.config import get_current_vllm_config
+            vllm_config = get_current_vllm_config()
+        _ENABLE_SP = (
+            vllm_config.compilation_config.pass_config.
+            enable_sequence_parallelism
+            or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
+            # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
+            # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
+            or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
+
+    return _ENABLE_SP
 
 
 # TODO remove it after vllm has this func
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 50696b4..8c61d04 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -815,7 +815,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         # Create a tensor for num_tokens_after_padding
         num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
                                                 self.dp_size,
-                                                device="npu",
+                                                device="cpu",
                                                 dtype=torch.int32)
         return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo