Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -3,14 +3,13 @@
from typing import Any
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
@@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
@@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits(
self,
@@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase):
pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
assert has_pplx(), (
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install pplx_kernels."
)
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
self.rank,
self.world_size,
)
uid = (
nvshmem_get_unique_id()
if self.rank == 0
else nvshmem_alloc_empty_unique_id()
)
dist.broadcast(
uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group,
)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx # type: ignore[import-not-found]
return self.handle_cache.get_or_create(
kwargs,
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels."
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
@@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
pass
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
self.handle_cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
@@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(
self,
@@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
# allow_nvlink_for_low_latency_mode=True,
# allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,