diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 57d966f70..dc120f761 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -43,6 +43,7 @@ from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, get_int_env_var, + is_cpu, is_cuda_alike, is_hip, is_npu, @@ -51,6 +52,7 @@ from sglang.srt.utils import ( ) _is_npu = is_npu() +_is_cpu = is_cpu() IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS") @@ -1643,7 +1645,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ray.shutdown() gc.collect() - if not current_platform.is_cpu(): + if not _is_cpu: if hasattr(torch, "cuda") and torch.cuda.is_available(): torch.cuda.empty_cache() if hasattr(torch._C, "_host_emptyCache"):