diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 99194f6..b4ef633 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -19,6 +19,7 @@ import os import vllm_ascend.patch.platform.patch_config # noqa import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa +import vllm_ascend.patch.platform.patch_sched_yield # noqa if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv( "EXPERT_MAP_RECORD", "false") == "true": diff --git a/vllm_ascend/patch/platform/patch_sched_yield.py b/vllm_ascend/patch/platform/patch_sched_yield.py new file mode 100644 index 0000000..694b957 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_sched_yield.py @@ -0,0 +1,13 @@ +import sys + +import vllm.distributed.utils +from vllm.platforms import CpuArchEnum, Platform + +is_arm = (Platform.get_cpu_architecture() == CpuArchEnum.ARM) + +USE_SCHED_YIELD = ( + ((sys.version_info[:3] >= (3, 11, 1)) or + (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8)) + and not is_arm) + +vllm.distributed.utils.USE_SCHED_YIELD = USE_SCHED_YIELD diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index d47ef33..50088f3 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -21,6 +21,7 @@ if HAS_TRITON: import vllm_ascend.patch.worker.patch_triton # isort: off +import vllm_ascend.patch.platform.patch_sched_yield # noqa import vllm_ascend.patch.worker.patch_distributed # noqa import vllm_ascend.patch.worker.patch_logits # noqa import vllm_ascend.patch.worker.patch_roberta # noqa