diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index e43bc0000..fdde7dde8 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1055,6 +1055,11 @@ def init_distributed_environment(
         world_size=world_size,
         rank=rank,
         timeout=timeout,
+        device_id=(
+            torch.device(f"cuda:{torch.cuda.current_device()}")
+            if torch.cuda.is_available()
+            else None  # CPU-only / no visible GPU: omit device so no eager NCCL init
+        ),  # Allow NCCL to eagerly init the communicator on CUDA builds
     )
     # set the local rank