Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..utils import StatelessProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
import ixformer.distributed as ixfd
import os
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
tcp_store_group: StatelessProcessGroup | None = None,
):
super().__init__(cpu_group, device, device_group, unique_name)
super().__init__(
cpu_group,
device,
device_group,
unique_name,
global_ranks,
global_world_size,
)
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
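
Note: the hunk above widens the constructor with global_ranks, global_world_size and tcp_store_group. A minimal sketch of how it might be invoked; the helper name and the surrounding variables (cpu_group, device_group, local_rank, world_size) are placeholders, not code from this commit.

import torch

def build_tp_communicator(cpu_group, device_group, local_rank: int, world_size: int):
    return CudaCommunicator(
        cpu_group=cpu_group,                     # CPU-side coordination group
        device=torch.device("cuda", local_rank),
        device_group=device_group,               # GPU collective group
        unique_name="tp",                        # only "tp" groups enable custom allreduce / symm mem
        global_ranks=list(range(world_size)),
        global_world_size=world_size,
        tcp_store_group=None,                    # or a StatelessProcessGroup for store-based setup
    )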
@@ -46,8 +57,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
self.use_flashinfer_allreduce = use_flashinfer_allreduce
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in ["1", "Y", "y"]
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device,
)
if is_symmetric_memory_enabled():
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
self.all2all_manager = NaiveAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
elif self.all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
self.all2all_manager = AgRsAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPHTAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPLLAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group, tcp_store_group
)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
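
Note: the elif chain above threads the same (cpu_group, tcp_store_group) pair into every updated manager. A hypothetical table-driven equivalent, shown only to summarize the pattern; the manager names come from this hunk, and pplx/mori are omitted because their updated signatures are not visible here.

from .all2all import (
    NaiveAll2AllManager,
    AgRsAll2AllManager,
    DeepEPHTAll2AllManager,
    DeepEPLLAll2AllManager,
    FlashInferAllToAllManager,
)

_ALL2ALL_MANAGERS = {
    "naive": NaiveAll2AllManager,
    "allgather_reducescatter": AgRsAll2AllManager,
    "deepep_high_throughput": DeepEPHTAll2AllManager,
    "deepep_low_latency": DeepEPLLAll2AllManager,
    "flashinfer_all2allv": FlashInferAllToAllManager,
}

def make_all2all_manager(backend: str, cpu_group, tcp_store_group):
    try:
        manager_cls = _ALL2ALL_MANAGERS[backend]
    except KeyError:
        raise ValueError(f"Unknown all2all backend: {backend}") from None
    # every updated manager now receives the TCP-store group alongside cpu_group
    return manager_cls(cpu_group, tcp_store_group)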
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
return out
if self.world_size == 1:
return input_
if self.use_vllm_comm:
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
pynccl_comm = self.pynccl_comm
if pynccl_comm is None or pynccl_comm.disabled:
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
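
Note: __init__ and all_reduce key off the same VLLM_FORCE_NCCL_COMM switch. A self-contained sketch of that gate, assuming ixformer.distributed mirrors the torch.distributed collective signatures (which is what the hunk above relies on); set VLLM_FORCE_NCCL_COMM=1 to fall back to plain torch.distributed/NCCL.

import os
import torch
import torch.distributed
import ixformer.distributed as ixfd  # CoreX collective backend

def gated_all_reduce(tensor: torch.Tensor, device_group) -> torch.Tensor:
    use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in ["1", "Y", "y"]
    if use_vllm_comm:
        # ixformer all-reduce, in place and launched asynchronously
        ixfd.all_reduce(tensor, group=device_group, async_op=True)
    else:
        # stock NCCL path via torch.distributed, also in place
        torch.distributed.all_reduce(tensor, group=device_group)
    return tensor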
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
# pynccl_comm.reduce_scatter(output, input_tensor)
torch.distributed.reduce_scatter_tensor(output,
input_tensor,
group=self.device_group)
# Perform reduce-scatter operation
ixfd.reduce_scatter_tensor(output, input_tensor, group=self.device_group, async_op=True)
# Reshape before returning
return output.movedim(0, dim).contiguous()
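
Note: only the collective changes here, from torch.distributed.reduce_scatter_tensor to ixfd.reduce_scatter_tensor with async_op=True; the shape handling around it (only partially visible in this hunk) follows the usual vLLM pattern. A reference sketch of that pattern using plain torch.distributed:

import torch
import torch.distributed

def reduce_scatter_along_dim(input_: torch.Tensor, dim: int, group) -> torch.Tensor:
    world_size = torch.distributed.get_world_size(group=group)
    # bring the scattered dim to the front so every rank receives a contiguous chunk
    input_tensor = input_.movedim(dim, 0).contiguous()
    assert input_tensor.shape[0] % world_size == 0
    output_shape = (input_tensor.shape[0] // world_size, *input_tensor.shape[1:])
    output = torch.empty(output_shape, dtype=input_tensor.dtype, device=input_tensor.device)
    torch.distributed.reduce_scatter_tensor(output, input_tensor, group=group)
    # restore the original dim order before returning
    return output.movedim(0, dim).contiguous()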
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
# if pynccl_comm is not None and not pynccl_comm.disabled:
# pynccl_comm.send(tensor, dst)
# else:
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
if self.use_vllm_comm:
ixfd.send(tensor, self.ranks[dst], self.device_group)
else:
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
# pynccl_comm = self.pynccl_comm
# if pynccl_comm is not None and not pynccl_comm.disabled:
# pynccl_comm.recv(tensor, src)
# else:
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
if self.use_vllm_comm:
ixfd.recv(tensor, self.ranks[src], self.device_group)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
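
Note: send and recv keep their default ring neighbours; only the transport is gated on use_vllm_comm. The defaults, spelled out as a small sketch (local ranks within the group):

def default_ring_neighbours(rank_in_group: int, world_size: int) -> tuple[int, int]:
    dst = (rank_in_group + 1) % world_size  # send() targets the next rank by default
    src = (rank_in_group - 1) % world_size  # recv() reads from the previous rank by default
    return dst, src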
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast a tensor from source rank to all ranks."""
if self.world_size == 1:
return tensor
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.broadcast(tensor, src)
return tensor
else:
raise ValueError("No PyNCCL communicator found")
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.fi_ar_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
self.all2all_manager = None # type: ignore[assignment]
def all_gatherv(
self,
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
# return self.all2all_manager.dispatch(
# hidden_states,
# topk_weights,
# topk_ids,
# is_sequence_parallel,
# extra_tensors=extra_tensors,
# )
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
hidden_states, extra_residual, router_logits)
return hidden_states, extra_residual, router_logits
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states,
is_sequence_parallel,
)
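
Note: the dispatch hunk above drops the CoreX-specific (hidden_states, extra_residual, router_logits) variant and restores the upstream (hidden_states, topk_weights, topk_ids, ...) call. An illustrative call order for an expert-parallel MoE layer; the names and the run_experts callback are placeholders, and the exact contents of the dispatched tuple depend on the active all2all manager.

def moe_dispatch_combine(communicator, hidden_states, topk_weights, topk_ids, run_experts):
    dispatched = communicator.dispatch(
        hidden_states, topk_weights, topk_ids, is_sequence_parallel=False
    )
    expert_output = run_experts(dispatched)  # local expert computation (placeholder)
    return communicator.combine(expert_output, is_sequence_parallel=False)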
def batch_isend_irecv(self, p2p_ops: list):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.batch_isend_irecv(p2p_ops)
else:
raise ValueError("No PyNCCL communicator found")
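
Note: batch_isend_irecv is new in this overlay and delegates to the PyNccl communicator. A hypothetical caller, assuming the op descriptors follow the torch.distributed.P2POp convention; the exact op type pynccl expects is not visible in this diff, and all names below are placeholders.

import torch
import torch.distributed as dist

def exchange_with_neighbours(communicator, send_buf: torch.Tensor,
                             recv_buf: torch.Tensor, next_rank: int, prev_rank: int):
    ops = [
        dist.P2POp(dist.isend, send_buf, peer=next_rank),
        dist.P2POp(dist.irecv, recv_buf, peer=prev_rank),
    ]
    # raises ValueError if no PyNCCL communicator is available
    communicator.batch_isend_irecv(ops)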