Clean up parallel_state.py (#11148)
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
"""vLLM distributed state.
|
||||
"""Distributed state.
|
||||
It takes over the control of the distributed environment from PyTorch.
|
||||
The typical workflow is:
|
||||
|
||||
@@ -53,19 +53,26 @@ from sglang.srt.utils import (
|
||||
|
||||
_is_npu = is_npu()
|
||||
_is_cpu = is_cpu()
|
||||
_supports_custom_op = supports_custom_op()
|
||||
|
||||
IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
|
||||
|
||||
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
||||
|
||||
# use int value instead of ReduceOp.SUM to support torch compile
|
||||
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphCaptureContext:
|
||||
stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
|
||||
|
||||
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
||||
|
||||
# use int value instead of ReduceOp.SUM to support torch compile
|
||||
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
|
||||
@dataclass
|
||||
class P2PWork:
|
||||
work: Optional[torch.distributed.Work]
|
||||
payload: Optional[torch.Tensor]
|
||||
|
||||
|
||||
def _split_tensor_dict(
|
||||
@@ -117,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
|
||||
_groups[group.unique_name] = weakref.ref(group)
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
if _supports_custom_op:
|
||||
|
||||
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
|
||||
assert group_name in _groups, f"Group {group_name} is not found."
|
||||
@@ -277,7 +284,7 @@ class GroupCoordinator:
|
||||
self.use_npu_communicator = use_npu_communicator
|
||||
self.use_message_queue_broadcaster = use_message_queue_broadcaster
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
# Lazy import to avoid documentation build error
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce,
|
||||
)
|
||||
@@ -497,7 +504,7 @@ class GroupCoordinator:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
if not supports_custom_op():
|
||||
if not _supports_custom_op:
|
||||
self._all_reduce_in_place(input_)
|
||||
return input_
|
||||
|
||||
@@ -523,23 +530,24 @@ class GroupCoordinator:
|
||||
|
||||
outplace_all_reduce_method = None
|
||||
if (
|
||||
self.qr_comm is not None
|
||||
and not self.qr_comm.disabled
|
||||
and self.qr_comm.should_quick_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "qr"
|
||||
elif (
|
||||
self.ca_comm is not None
|
||||
and not self.ca_comm.disabled
|
||||
and self.ca_comm.should_custom_ar(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "ca"
|
||||
elif (
|
||||
self.qr_comm is not None
|
||||
and not self.qr_comm.disabled
|
||||
and self.qr_comm.should_quick_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "qr"
|
||||
elif (
|
||||
self.pymscclpp_comm is not None
|
||||
and not self.pymscclpp_comm.disabled
|
||||
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "pymscclpp"
|
||||
|
||||
if outplace_all_reduce_method is not None:
|
||||
return torch.ops.sglang.outplace_all_reduce(
|
||||
input_,
|
||||
@@ -553,16 +561,16 @@ class GroupCoordinator:
|
||||
def _all_reduce_out_place(
|
||||
self, input_: torch.Tensor, outplace_all_reduce_method: str
|
||||
) -> torch.Tensor:
|
||||
qr_comm = self.qr_comm
|
||||
ca_comm = self.ca_comm
|
||||
qr_comm = self.qr_comm
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
assert any([qr_comm, ca_comm, pymscclpp_comm])
|
||||
if outplace_all_reduce_method == "qr":
|
||||
assert not qr_comm.disabled
|
||||
out = qr_comm.quick_all_reduce(input_)
|
||||
elif outplace_all_reduce_method == "ca":
|
||||
if outplace_all_reduce_method == "ca":
|
||||
assert not ca_comm.disabled
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
elif outplace_all_reduce_method == "qr":
|
||||
assert not qr_comm.disabled
|
||||
out = qr_comm.quick_all_reduce(input_)
|
||||
else:
|
||||
assert not pymscclpp_comm.disabled
|
||||
out = pymscclpp_comm.all_reduce(input_)
|
||||
@@ -637,7 +645,7 @@ class GroupCoordinator:
|
||||
)
|
||||
|
||||
def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
|
||||
if _is_npu or not supports_custom_op():
|
||||
if _is_npu or not _supports_custom_op:
|
||||
self._all_gather_into_tensor(output, input)
|
||||
else:
|
||||
torch.ops.sglang.reg_all_gather_into_tensor(
|
||||
@@ -697,15 +705,13 @@ class GroupCoordinator:
|
||||
)
|
||||
|
||||
# All-gather.
|
||||
if input_.is_cpu and is_shm_available(
|
||||
input_.dtype, self.world_size, self.local_size
|
||||
):
|
||||
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
|
||||
|
||||
if input_.is_cpu:
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
output_tensor, input_, group=self.device_group
|
||||
)
|
||||
if is_shm_available(input_.dtype, self.world_size, self.local_size):
|
||||
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
output_tensor, input_, group=self.device_group
|
||||
)
|
||||
else:
|
||||
self.all_gather_into_tensor(output_tensor, input_)
|
||||
|
||||
@@ -861,45 +867,63 @@ class GroupCoordinator:
|
||||
torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
|
||||
return objs
|
||||
|
||||
def send_object(self, obj: Any, dst: int) -> None:
|
||||
"""Send the input object list to the destination rank."""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
def send_object(
|
||||
self,
|
||||
obj: Any,
|
||||
dst: int,
|
||||
async_send: bool = False,
|
||||
) -> List[P2PWork]:
|
||||
"""
|
||||
Send the input object list to the destination rank.
|
||||
This function uses the CPU group for all communications.
|
||||
|
||||
TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
|
||||
use other functions (e.g., send), or implement a new function (e.g., send_object_device).
|
||||
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
assert dst != self.rank_in_group, (
|
||||
"Invalid destination rank. Destination rank is the same "
|
||||
"as the current rank."
|
||||
)
|
||||
send_func = torch.distributed.isend if async_send else torch.distributed.send
|
||||
|
||||
# Serialize object to tensor and get the size as well
|
||||
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
|
||||
device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
|
||||
size_tensor = torch.tensor(
|
||||
[object_tensor.numel()],
|
||||
dtype=torch.long,
|
||||
device="cpu",
|
||||
[object_tensor.numel()], dtype=torch.long, device="cpu"
|
||||
)
|
||||
|
||||
# Send object size
|
||||
torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
|
||||
|
||||
# Send object
|
||||
torch.distributed.send(
|
||||
object_tensor,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group,
|
||||
p2p_work = []
|
||||
size_work = send_func(
|
||||
size_tensor,
|
||||
self.ranks[dst],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
if async_send:
|
||||
p2p_work.append(P2PWork(size_work, size_tensor))
|
||||
|
||||
return None
|
||||
object_work = send_func(
|
||||
object_tensor,
|
||||
self.ranks[dst],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
if async_send:
|
||||
p2p_work.append(P2PWork(object_work, object_tensor))
|
||||
|
||||
def recv_object(self, src: int) -> Any:
|
||||
return p2p_work
|
||||
|
||||
def recv_object(
|
||||
self,
|
||||
src: int,
|
||||
) -> Any:
|
||||
"""Receive the input object list from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
assert (
|
||||
src != self.rank_in_group
|
||||
), "Invalid source rank. Source rank is the same as the current rank."
|
||||
@@ -907,27 +931,25 @@ class GroupCoordinator:
|
||||
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
|
||||
|
||||
# Receive object size
|
||||
rank_size = torch.distributed.recv(
|
||||
# We have to use irecv here to make it work for both isend and send.
|
||||
work = torch.distributed.irecv(
|
||||
size_tensor, src=self.ranks[src], group=self.cpu_group
|
||||
)
|
||||
work.wait()
|
||||
|
||||
# Tensor to receive serialized objects into.
|
||||
object_tensor = torch.empty( # type: ignore[call-overload]
|
||||
object_tensor: Any = torch.empty( # type: ignore[call-overload]
|
||||
size_tensor.item(), # type: ignore[arg-type]
|
||||
dtype=torch.uint8,
|
||||
device=torch.cuda.current_device(),
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
rank_object = torch.distributed.recv(
|
||||
object_tensor, src=self.ranks[src], group=self.device_group
|
||||
work = torch.distributed.irecv(
|
||||
object_tensor, src=self.ranks[src], group=self.cpu_group
|
||||
)
|
||||
work.wait()
|
||||
|
||||
assert (
|
||||
rank_object == rank_size
|
||||
), "Received object sender rank does not match the size sender rank."
|
||||
|
||||
obj = pickle.loads(object_tensor.cpu().numpy())
|
||||
|
||||
obj = pickle.loads(object_tensor.numpy())
|
||||
return obj
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
@@ -1017,12 +1039,13 @@ class GroupCoordinator:
|
||||
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
|
||||
dst: Optional[int] = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
||||
async_send: bool = False,
|
||||
) -> Optional[List[P2PWork]]:
|
||||
"""Send the input tensor dictionary.
|
||||
NOTE: `dst` is the local rank of the source rank.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
if self.world_size == 1:
|
||||
return tensor_dict
|
||||
|
||||
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
|
||||
@@ -1047,7 +1070,10 @@ class GroupCoordinator:
|
||||
# 1. Superior D2D transfer bandwidth
|
||||
# 2. Ability to overlap send and recv operations
|
||||
# Thus the net performance gain justifies this approach.
|
||||
self.send_object(metadata_list, dst=dst)
|
||||
|
||||
send_func = torch.distributed.isend if async_send else torch.distributed.send
|
||||
p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
|
||||
|
||||
for tensor in tensor_list:
|
||||
if tensor.numel() == 0:
|
||||
# Skip sending empty tensors.
|
||||
@@ -1057,15 +1083,10 @@ class GroupCoordinator:
|
||||
if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
|
||||
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
|
||||
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.send(
|
||||
tensor, dst=self.ranks[dst], group=metadata_group
|
||||
)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
|
||||
return None
|
||||
comm_group = metadata_group if tensor.is_cpu else group
|
||||
work = send_func(tensor, self.ranks[dst], group=comm_group)
|
||||
p2p_works.append(P2PWork(work, tensor))
|
||||
return p2p_works
|
||||
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
@@ -1111,17 +1132,15 @@ class GroupCoordinator:
|
||||
orig_shape = tensor.shape
|
||||
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
|
||||
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.recv(
|
||||
tensor, src=self.ranks[src], group=metadata_group
|
||||
)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
torch.distributed.recv(tensor, src=self.ranks[src], group=group)
|
||||
# We have to use irecv here to make it work for both isend and send.
|
||||
comm_group = metadata_group if tensor.is_cpu else group
|
||||
work = torch.distributed.irecv(
|
||||
tensor, src=self.ranks[src], group=comm_group
|
||||
)
|
||||
work.wait()
|
||||
|
||||
if use_all_gather:
|
||||
# do the allgather
|
||||
tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore
|
||||
tensor = all_gather_group.all_gather(tensor, dim=0)
|
||||
tensor = tensor.reshape(orig_shape)
|
||||
|
||||
tensor_dict[key] = tensor
|
||||
|
||||
Reference in New Issue
Block a user