[Auto Sync] Update parallel_state.py (20250830) (#9828)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -52,6 +52,8 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
_is_npu = is_npu()
|
_is_npu = is_npu()
|
||||||
|
|
||||||
|
IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphCaptureContext:
|
class GraphCaptureContext:
|
||||||
@@ -223,10 +225,12 @@ class GroupCoordinator:
|
|||||||
use_message_queue_broadcaster: bool = False,
|
use_message_queue_broadcaster: bool = False,
|
||||||
group_name: Optional[str] = None,
|
group_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
# Set group info
|
||||||
group_name = group_name or "anonymous"
|
group_name = group_name or "anonymous"
|
||||||
self.unique_name = _get_unique_name(group_name)
|
self.unique_name = _get_unique_name(group_name)
|
||||||
_register_group(self)
|
_register_group(self)
|
||||||
|
|
||||||
|
# Set rank info
|
||||||
self.rank = torch.distributed.get_rank()
|
self.rank = torch.distributed.get_rank()
|
||||||
self.local_rank = local_rank
|
self.local_rank = local_rank
|
||||||
self.device_group = None
|
self.device_group = None
|
||||||
@@ -250,14 +254,16 @@ class GroupCoordinator:
|
|||||||
assert self.cpu_group is not None
|
assert self.cpu_group is not None
|
||||||
assert self.device_group is not None
|
assert self.device_group is not None
|
||||||
|
|
||||||
|
device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
|
||||||
if is_cuda_alike():
|
if is_cuda_alike():
|
||||||
self.device = torch.device(f"cuda:{local_rank}")
|
self.device = torch.device(f"cuda:{device_id}")
|
||||||
elif _is_npu:
|
elif _is_npu:
|
||||||
self.device = torch.device(f"npu:{local_rank}")
|
self.device = torch.device(f"npu:{device_id}")
|
||||||
else:
|
else:
|
||||||
self.device = torch.device("cpu")
|
self.device = torch.device("cpu")
|
||||||
self.device_module = torch.get_device_module(self.device)
|
self.device_module = torch.get_device_module(self.device)
|
||||||
|
|
||||||
|
# Import communicators
|
||||||
self.use_pynccl = use_pynccl
|
self.use_pynccl = use_pynccl
|
||||||
self.use_pymscclpp = use_pymscclpp
|
self.use_pymscclpp = use_pymscclpp
|
||||||
self.use_custom_allreduce = use_custom_allreduce
|
self.use_custom_allreduce = use_custom_allreduce
|
||||||
@@ -270,6 +276,9 @@ class GroupCoordinator:
|
|||||||
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
||||||
CustomAllreduce,
|
CustomAllreduce,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed.device_communicators.pymscclpp import (
|
||||||
|
PyMscclppCommunicator,
|
||||||
|
)
|
||||||
from sglang.srt.distributed.device_communicators.pynccl import (
|
from sglang.srt.distributed.device_communicators.pynccl import (
|
||||||
PyNcclCommunicator,
|
PyNcclCommunicator,
|
||||||
)
|
)
|
||||||
@@ -287,10 +296,6 @@ class GroupCoordinator:
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.distributed.device_communicators.pymscclpp import (
|
|
||||||
PyMscclppCommunicator,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
|
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
|
||||||
if use_pymscclpp and self.world_size > 1:
|
if use_pymscclpp and self.world_size > 1:
|
||||||
self.pymscclpp_comm = PyMscclppCommunicator(
|
self.pymscclpp_comm = PyMscclppCommunicator(
|
||||||
@@ -325,30 +330,30 @@ class GroupCoordinator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
|
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
|
||||||
|
|
||||||
|
# Create communicator for other hardware backends
|
||||||
from sglang.srt.distributed.device_communicators.hpu_communicator import (
|
from sglang.srt.distributed.device_communicators.hpu_communicator import (
|
||||||
HpuCommunicator,
|
HpuCommunicator,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed.device_communicators.npu_communicator import (
|
||||||
|
NpuCommunicator,
|
||||||
|
)
|
||||||
|
from sglang.srt.distributed.device_communicators.xpu_communicator import (
|
||||||
|
XpuCommunicator,
|
||||||
|
)
|
||||||
|
|
||||||
self.hpu_communicator: Optional[HpuCommunicator] = None
|
self.hpu_communicator: Optional[HpuCommunicator] = None
|
||||||
if use_hpu_communicator and self.world_size > 1:
|
if use_hpu_communicator and self.world_size > 1:
|
||||||
self.hpu_communicator = HpuCommunicator(group=self.device_group)
|
self.hpu_communicator = HpuCommunicator(group=self.device_group)
|
||||||
|
|
||||||
from sglang.srt.distributed.device_communicators.xpu_communicator import (
|
|
||||||
XpuCommunicator,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.xpu_communicator: Optional[XpuCommunicator] = None
|
self.xpu_communicator: Optional[XpuCommunicator] = None
|
||||||
if use_xpu_communicator and self.world_size > 1:
|
if use_xpu_communicator and self.world_size > 1:
|
||||||
self.xpu_communicator = XpuCommunicator(group=self.device_group)
|
self.xpu_communicator = XpuCommunicator(group=self.device_group)
|
||||||
|
|
||||||
from sglang.srt.distributed.device_communicators.npu_communicator import (
|
|
||||||
NpuCommunicator,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.npu_communicator: Optional[NpuCommunicator] = None
|
self.npu_communicator: Optional[NpuCommunicator] = None
|
||||||
if use_npu_communicator and self.world_size > 1:
|
if use_npu_communicator and self.world_size > 1:
|
||||||
self.npu_communicator = NpuCommunicator(group=self.device_group)
|
self.npu_communicator = NpuCommunicator(group=self.device_group)
|
||||||
|
|
||||||
|
# Create message queue
|
||||||
from sglang.srt.distributed.device_communicators.shm_broadcast import (
|
from sglang.srt.distributed.device_communicators.shm_broadcast import (
|
||||||
MessageQueue,
|
MessageQueue,
|
||||||
)
|
)
|
||||||
@@ -848,6 +853,11 @@ class GroupCoordinator:
|
|||||||
)
|
)
|
||||||
return obj_list
|
return obj_list
|
||||||
|
|
||||||
|
def all_gather_object(self, obj: Any) -> List[Any]:
|
||||||
|
objs = [None] * self.world_size
|
||||||
|
torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
|
||||||
|
return objs
|
||||||
|
|
||||||
def send_object(self, obj: Any, dst: int) -> None:
|
def send_object(self, obj: Any, dst: int) -> None:
|
||||||
"""Send the input object list to the destination rank."""
|
"""Send the input object list to the destination rank."""
|
||||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||||
|
|||||||
Reference in New Issue
Block a user