Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -18,7 +18,7 @@ from datetime import timedelta
from typing import Any
import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed import ProcessGroup, Store, TCPStore
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
gathered_objs.append(recv_obj)
return gathered_objs
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Broadcast a tensor from source rank to all other ranks.

    The source rank serializes the tensor into the shared store; every
    other rank deserializes it from there. Per-source monotonic counters
    keep successive broadcasts from colliding on the same store key.
    """
    if self.rank != src:
        # Receiver: fetch the payload the source published for this round.
        store_key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
        received = pickle.loads(self.store.get(store_key))
        self.broadcast_recv_src_counter[src] += 1
        return received
    # Sender: evict stale entries first, then publish the serialized tensor
    # and remember the key so it can be expired later.
    payload = pickle.dumps(tensor)
    self.expire_data()
    store_key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
    self.store.set(store_key, payload)
    self.broadcast_send_counter += 1
    self.entries.append((store_key, time.time()))
    return tensor
def send(self, tensor: torch.Tensor, dst: int):
    """Send a tensor to a destination rank via the shared store."""
    # Drop expired entries before publishing a new one.
    self.expire_data()
    seq = self.send_dst_counter[dst]
    store_key = f"send_tensor/{dst}/{seq}"
    self.store.set(store_key, pickle.dumps(tensor))
    self.send_dst_counter[dst] = seq + 1
    # Track the key with a timestamp so expire_data can reclaim it later.
    self.entries.append((store_key, time.time()))
def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Receive a tensor from a source rank.

    Blocks on the store until the sender has published the payload, then
    copies it into ``tensor`` in place and returns ``tensor``.
    """
    # NOTE(review): the key is indexed by our own rank rather than by
    # `src`; presumably only the matching sender writes under this rank's
    # key space — confirm against the `send` side usage.
    seq = self.recv_src_counter[src]
    payload = self.store.get(f"send_tensor/{self.rank}/{seq}")
    self.recv_src_counter[src] = seq + 1
    tensor.copy_(pickle.loads(payload))
    return tensor
def all_reduce(
    self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
) -> torch.Tensor:
    """All-reduce a tensor across all ranks.

    Implemented on top of ``all_gather_obj``: every rank gathers all
    ranks' tensors and reduces them locally, so the result is identical
    on every rank.

    Args:
        tensor: This rank's local contribution.
        op: Reduction operator; SUM, PRODUCT, MAX and MIN are supported.

    Returns:
        A new tensor holding the reduction over all ranks; the input
        tensor is not modified.

    Raises:
        ValueError: If ``op`` is not a supported reduction operator.
    """
    tensors = self.all_gather_obj(tensor)
    result = tensors[0].clone()
    # Dispatch on the op once, outside the loop, and fail fast on an
    # unsupported op instead of silently returning tensors[0] unchanged
    # (the original fell through every branch and produced a wrong result).
    reduce_op = torch.distributed.ReduceOp
    if op == reduce_op.SUM:
        for t in tensors[1:]:
            result.add_(t)
    elif op == reduce_op.PRODUCT:
        for t in tensors[1:]:
            result.mul_(t)
    elif op == reduce_op.MAX:
        for t in tensors[1:]:
            result = torch.maximum(result, t)
    elif op == reduce_op.MIN:
        for t in tensors[1:]:
            result = torch.minimum(result, t)
    else:
        raise ValueError(f"Unsupported reduce op: {op}")
    return result
def barrier(self, timeout: float = 30.0):
"""A robust barrier to synchronize all ranks.
@@ -448,8 +497,14 @@ def init_gloo_process_group(
def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
host: str,
port: int,
rank: int,
world_size: int,
backend: str,
group_name: str | None = None,
return_store: bool = False,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
try:
if backend == "gloo":
pg = init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
else:
from vllm.platforms import current_platform
return current_platform.stateless_init_device_torch_dist_pg(
pg = current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
except NotImplementedError:
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
# will raise a NotImplementedError. In this case, we fall back to gloo.
return init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
if group_name is not None:
from torch._C._distributed_c10d import _register_process_group
pg._set_group_name(group_name)
_register_process_group(group_name, pg)
if return_store:
return pg, store
else:
return pg
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: