From 111991fe2335fbfeb03330209bca1b051b11e69f Mon Sep 17 00:00:00 2001 From: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com> Date: Wed, 12 Jun 2024 14:27:17 +0800 Subject: [PATCH] Fix Regression: Disable p2p for 4090 (#531) Co-authored-by: Qubitium <417764+Qubitium@users.noreply.github.com> --- python/sglang/srt/managers/controller/model_runner.py | 2 +- python/sglang/srt/utils.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index 991700bc9..11c198be4 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -241,7 +241,7 @@ class ModelRunner: logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.") torch.cuda.set_device(self.gpu_id) logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.") - monkey_patch_vllm_p2p_access_check() + monkey_patch_vllm_p2p_access_check(self.gpu_id) init_distributed_environment( backend="nccl", world_size=self.tp_size, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index bddb3ded5..43ae6f62a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -453,14 +453,18 @@ def kill_parent_process(): os.kill(parent_process.pid, 9) -def monkey_patch_vllm_p2p_access_check(): +def monkey_patch_vllm_p2p_access_check(gpu_id: int): """ Monkey patch the slow p2p access check in vllm. NOTE: We assume the p2p access is always allowed, which can be wrong for some setups. """ - import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt - setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) + # TODO: need a better check than just dev str name match + # compat: skip RTX 40 series as they do not have P2P feature and even checking for them may cause errors + device_name = torch.cuda.get_device_name(gpu_id) + if "RTX 40" not in device_name: + import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt + setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) API_KEY_HEADER_NAME = "X-API-Key"