Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import StatelessProcessGroup
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
import ixformer.distributed as ixfd
|
||||
import os
|
||||
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
tcp_store_group: StatelessProcessGroup | None = None,
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
super().__init__(
|
||||
cpu_group,
|
||||
device,
|
||||
device_group,
|
||||
unique_name,
|
||||
global_ranks,
|
||||
global_world_size,
|
||||
)
|
||||
if "tp" not in unique_name:
|
||||
# custom allreduce or torch symm mem can be used only by tp
|
||||
use_custom_allreduce = False
|
||||
@@ -46,8 +57,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
self.use_flashinfer_allreduce = use_flashinfer_allreduce
|
||||
|
||||
|
||||
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"]
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce,
|
||||
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
|
||||
device=self.device,
|
||||
)
|
||||
if is_symmetric_memory_enabled():
|
||||
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = NaiveAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = AgRsAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "mori":
|
||||
from .all2all import MoriAll2AllManager
|
||||
|
||||
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
elif self.all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
|
||||
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
||||
self.all2all_manager = FlashInferAllToAllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
|
||||
|
||||
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
return out
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
if self.use_vllm_comm:
|
||||
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
|
||||
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is None or pynccl_comm.disabled:
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
assert pynccl_comm is not None
|
||||
out = pynccl_comm.all_reduce(input_)
|
||||
if out is None:
|
||||
# fall back to the default all-reduce using PyTorch.
|
||||
# this usually happens during testing.
|
||||
# when we run the model, allreduce only happens for the TP
|
||||
# group, where we always have either custom allreduce or pynccl.
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
# pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
torch.distributed.reduce_scatter_tensor(output,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
# Perform reduce-scatter operation
|
||||
ixfd.reduce_scatter_tensor(output,input_tensor,group=self.device_group, async_op=True)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.send(tensor, dst)
|
||||
# else:
|
||||
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.send(tensor, self.ranks[dst], self.device_group)
|
||||
else:
|
||||
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
# pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.recv(tensor, src)
|
||||
# else:
|
||||
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.recv(tensor, self.ranks[src], self.device_group)
|
||||
else:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast ``tensor`` from rank ``src`` via the PyNCCL communicator.

    When the group has a single member the tensor is returned as-is and
    no communication happens. Otherwise the call is delegated to
    ``self.pynccl_comm.broadcast(tensor, src)`` and the same tensor
    object is handed back to the caller.

    Args:
        tensor: The tensor passed to the communicator's ``broadcast``.
        src: Source rank forwarded unchanged to the communicator.
            NOTE(review): presumably a rank local to this group, like
            ``dst``/``src`` in ``send``/``recv`` -- confirm.

    Returns:
        The ``tensor`` argument itself.

    Raises:
        ValueError: If ``self.pynccl_comm`` is ``None`` or disabled.
    """
    if self.world_size == 1:
        return tensor

    comm = self.pynccl_comm
    # There is no fallback path here: a missing or disabled PyNCCL
    # communicator is reported as an error.
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")

    comm.broadcast(tensor, src)
    return tensor
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm = None
|
||||
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.fi_ar_comm = None
|
||||
if self.all2all_manager is not None:
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
self.all2all_manager = None # type: ignore[assignment]
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
extra_residual:torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
# return self.all2all_manager.dispatch(
|
||||
# hidden_states,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# is_sequence_parallel,
|
||||
# extra_tensors=extra_tensors,
|
||||
# )
|
||||
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, extra_residual, router_logits)
|
||||
return hidden_states, extra_residual, router_logits
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
    """Forward a batch of point-to-point ops to the PyNCCL communicator.

    Args:
        p2p_ops: List handed unchanged to
            ``self.pynccl_comm.batch_isend_irecv``.
            NOTE(review): element type is not visible here -- confirm
            against the PyNCCL communicator.

    Raises:
        ValueError: If ``self.pynccl_comm`` is ``None`` or disabled.
    """
    comm = self.pynccl_comm
    # Guard first: without a usable PyNCCL communicator there is no
    # alternative transport in this method.
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")

    comm.batch_isend_irecv(p2p_ops)
|
||||
|
||||
Reference in New Issue
Block a user