Clean up parallel_state.py (#11148)
@@ -4,7 +4,7 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-"""vLLM distributed state.
+"""Distributed state.
 It takes over the control of the distributed environment from PyTorch.
 The typical workflow is:
 
@@ -53,19 +53,26 @@ from sglang.srt.utils import (
 
 _is_npu = is_npu()
 _is_cpu = is_cpu()
+_supports_custom_op = supports_custom_op()
 
 IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
 
 
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+
+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
 
 
-TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
-
-# use int value instead of ReduceOp.SUM to support torch compile
-REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+@dataclass
+class P2PWork:
+    work: Optional[torch.distributed.Work]
+    payload: Optional[torch.Tensor]
 
 
 def _split_tensor_dict(
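The new P2PWork dataclass pairs an asynchronous torch.distributed.Work handle with the tensor it refers to, so the payload cannot be garbage-collected while an isend/irecv is still in flight. A minimal sketch of how a caller might drain such a list (the helper name wait_p2p_works is hypothetical, not part of this change):

from typing import List

def wait_p2p_works(works: List["P2PWork"]) -> None:
    # Block until every outstanding isend/irecv has completed.
    # Holding `payload` inside each P2PWork keeps the buffer alive
    # for the whole duration of the transfer.
    for p2p_work in works:
        if p2p_work.work is not None:
            p2p_work.work.wait()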
@@ -117,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)
 
 
-if supports_custom_op():
+if _supports_custom_op:
 
     def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         assert group_name in _groups, f"Group {group_name} is not found."
@@ -277,7 +284,7 @@ class GroupCoordinator:
         self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
-        # lazy import to avoid documentation build error
+        # Lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
@@ -497,7 +504,7 @@ class GroupCoordinator:
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
-        if not supports_custom_op():
+        if not _supports_custom_op:
             self._all_reduce_in_place(input_)
             return input_
 
@@ -523,23 +530,24 @@ class GroupCoordinator:
 
         outplace_all_reduce_method = None
         if (
-            self.qr_comm is not None
-            and not self.qr_comm.disabled
-            and self.qr_comm.should_quick_allreduce(input_)
-        ):
-            outplace_all_reduce_method = "qr"
-        elif (
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
         ):
             outplace_all_reduce_method = "ca"
+        elif (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
         elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
             outplace_all_reduce_method = "pymscclpp"
 
         if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
                 input_,
@@ -553,16 +561,16 @@ class GroupCoordinator:
     def _all_reduce_out_place(
         self, input_: torch.Tensor, outplace_all_reduce_method: str
     ) -> torch.Tensor:
-        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
+        qr_comm = self.qr_comm
         pymscclpp_comm = self.pymscclpp_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
-        if outplace_all_reduce_method == "qr":
-            assert not qr_comm.disabled
-            out = qr_comm.quick_all_reduce(input_)
-        elif outplace_all_reduce_method == "ca":
+        if outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
+        elif outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
@@ -637,7 +645,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if _is_npu or not supports_custom_op():
+        if _is_npu or not _supports_custom_op:
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -697,15 +705,13 @@ class GroupCoordinator:
         )
 
         # All-gather.
-        if input_.is_cpu and is_shm_available(
-            input_.dtype, self.world_size, self.local_size
-        ):
-            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-
         if input_.is_cpu:
-            torch.distributed.all_gather_into_tensor(
-                output_tensor, input_, group=self.device_group
-            )
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
         else:
             self.all_gather_into_tensor(output_tensor, input_)
 
@@ -861,45 +867,63 @@ class GroupCoordinator:
         torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
         return objs
 
-    def send_object(self, obj: Any, dst: int) -> None:
-        """Send the input object list to the destination rank."""
-        """NOTE: `dst` is the local rank of the destination rank."""
+    def send_object(
+        self,
+        obj: Any,
+        dst: int,
+        async_send: bool = False,
+    ) -> List[P2PWork]:
+        """
+        Send the input object list to the destination rank.
+        This function uses the CPU group for all communications.
+
+        TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
+        use other functions (e.g., send), or implement a new function (e.g., send_object_device).
+
+        NOTE: `dst` is the local rank of the destination rank.
+        """
 
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
         assert dst != self.rank_in_group, (
             "Invalid destination rank. Destination rank is the same "
             "as the current rank."
         )
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
 
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
-            device=torch.cuda.current_device()
-        )
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
 
         size_tensor = torch.tensor(
-            [object_tensor.numel()],
-            dtype=torch.long,
-            device="cpu",
+            [object_tensor.numel()], dtype=torch.long, device="cpu"
         )
 
         # Send object size
-        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
-
-        # Send object
-        torch.distributed.send(
-            object_tensor,
-            dst=self.ranks[dst],
-            group=self.device_group,
-        )
+        p2p_work = []
+        size_work = send_func(
+            size_tensor,
+            self.ranks[dst],
+            group=self.cpu_group,
+        )
+        if async_send:
+            p2p_work.append(P2PWork(size_work, size_tensor))
 
-        return None
+        object_work = send_func(
+            object_tensor,
+            self.ranks[dst],
+            group=self.cpu_group,
+        )
+        if async_send:
+            p2p_work.append(P2PWork(object_work, object_tensor))
 
-    def recv_object(self, src: int) -> Any:
+        return p2p_work
+
+    def recv_object(
+        self,
+        src: int,
+    ) -> Any:
         """Receive the input object list from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
 
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         assert (
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
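With async_send=True, send_object now returns the isend handles instead of blocking, and recv_object switches to irecv plus wait so it pairs with either the sync or the async sender. A rough usage sketch, assuming an already-initialized two-rank GroupCoordinator named group (rank 0 sends, rank 1 receives; names and values are illustrative):

# Rank 0: queue the sends and keep the returned handles alive.
pending = group.send_object({"step": 42}, dst=1, async_send=True)
# ... overlap other work here ...
for p2p_work in pending:
    p2p_work.work.wait()  # complete the isend before reusing or freeing buffers

# Rank 1: recv_object blocks until both the size and the payload arrive.
obj = group.recv_object(src=0)
assert obj == {"step": 42}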
@@ -907,27 +931,25 @@ class GroupCoordinator:
         size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
 
         # Receive object size
-        rank_size = torch.distributed.recv(
+        # We have to use irecv here to make it work for both isend and send.
+        work = torch.distributed.irecv(
             size_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()
 
         # Tensor to receive serialized objects into.
-        object_tensor = torch.empty(  # type: ignore[call-overload]
+        object_tensor: Any = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=torch.cuda.current_device(),
+            device="cpu",
         )
 
-        rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.device_group
+        work = torch.distributed.irecv(
+            object_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()
 
-        assert (
-            rank_object == rank_size
-        ), "Received object sender rank does not match the size sender rank."
-
-        obj = pickle.loads(object_tensor.cpu().numpy())
+        obj = pickle.loads(object_tensor.numpy())
 
         return obj
 
     def broadcast_tensor_dict(
@@ -1017,12 +1039,13 @@ class GroupCoordinator:
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        async_send: bool = False,
+    ) -> Optional[List[P2PWork]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
         # Bypass the function if we are using only 1 GPU.
-        if not torch.distributed.is_initialized() or self.world_size == 1:
+        if self.world_size == 1:
             return tensor_dict
 
         all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
@@ -1047,7 +1070,10 @@ class GroupCoordinator:
         # 1. Superior D2D transfer bandwidth
         # 2. Ability to overlap send and recv operations
         # Thus the net performance gain justifies this approach.
-        self.send_object(metadata_list, dst=dst)
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
+        p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
+
         for tensor in tensor_list:
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
@@ -1057,15 +1083,10 @@ class GroupCoordinator:
             if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.send(
-                    tensor, dst=self.ranks[dst], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
-        return None
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = send_func(tensor, self.ranks[dst], group=comm_group)
+            p2p_works.append(P2PWork(work, tensor))
+        return p2p_works
 
     def recv_tensor_dict(
         self,
@@ -1111,17 +1132,15 @@ class GroupCoordinator:
                 orig_shape = tensor.shape
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.recv(
-                    tensor, src=self.ranks[src], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.recv(tensor, src=self.ranks[src], group=group)
+            # We have to use irecv here to make it work for both isend and send.
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = torch.distributed.irecv(
+                tensor, src=self.ranks[src], group=comm_group
+            )
+            work.wait()
             if use_all_gather:
-                # do the allgather
-                tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore
+                tensor = all_gather_group.all_gather(tensor, dim=0)
                 tensor = tensor.reshape(orig_shape)
 
             tensor_dict[key] = tensor
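send_tensor_dict follows the same pattern: with async_send=True it returns one P2PWork per posted isend (metadata and tensors), while recv_tensor_dict keeps its signature and pairs with either sender via irecv/wait. A hedged end-to-end sketch; the GroupCoordinator instance group, the rank numbering, and the example tensors are illustrative assumptions, not part of this change:

import torch

# Sender side (e.g., rank 0 in `group`): post everything asynchronously.
works = group.send_tensor_dict(
    {"hidden": torch.randn(8, 16, device="cuda"), "meta": {"layer": 3}},
    dst=1,
    async_send=True,
)
for p2p_work in works:
    if p2p_work.work is not None:  # entries carry None when async_send=False
        p2p_work.work.wait()  # tensors must stay alive until this returns

# Receiver side (rank 1): unchanged call signature.
received = group.recv_tensor_dict(src=0)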