Upgrade to vLLM 0.17.0 (Corex v4.1 overlay)
This commit is contained in:
@@ -18,7 +18,7 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, TCPStore
|
||||
from torch.distributed import ProcessGroup, Store, TCPStore
|
||||
from torch.distributed.distributed_c10d import (
|
||||
Backend,
|
||||
PrefixStore,
|
||||
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
|
||||
gathered_objs.append(recv_obj)
|
||||
return gathered_objs
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Broadcast a tensor from source rank to all other ranks.

    The source rank serializes the tensor into the shared KV store under a
    sequenced key; every other rank reads it back. Matching monotonic
    counters on the send and receive sides keep the key sequence in
    lockstep without any direct rank-to-rank communication.
    """
    if self.rank != src:
        # Receiver path: fetch the blob the source published for our next
        # expected sequence number, then advance the per-source counter.
        key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
        # NOTE(review): pickle.loads on store contents — acceptable only
        # because the store is private to this process group, not
        # untrusted external input.
        received = pickle.loads(self.store.get(key))
        self.broadcast_recv_src_counter[src] += 1
        return received
    # Sender path: serialize, drop expired entries, publish under the next
    # sequence key, and remember the key with a timestamp for later expiry.
    payload = pickle.dumps(tensor)
    self.expire_data()
    key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
    self.store.set(key, payload)
    self.broadcast_send_counter += 1
    self.entries.append((key, time.time()))
    return tensor
||||
def send(self, tensor: torch.Tensor, dst: int):
    """Send a tensor to a destination rank through the shared KV store.

    Publishes the pickled tensor under a key sequenced by the per-dst send
    counter; the receiver's matching counter lets it find the key.
    """
    # Drop stale store keys before adding a new one.
    self.expire_data()
    seq = self.send_dst_counter[dst]
    key = f"send_tensor/{dst}/{seq}"
    self.store.set(key, pickle.dumps(tensor))
    self.send_dst_counter[dst] = seq + 1
    # Record the key with a timestamp so expire_data can reclaim it later.
    self.entries.append((key, time.time()))
||||
def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Receive a tensor from ``src``, copying it into ``tensor`` in place.

    Returns the (mutated) input tensor for convenience.
    """
    seq = self.recv_src_counter[src]
    # Keys are addressed by the *receiver's* rank: the sender wrote under
    # send_tensor/<dst>/<seq>, and this rank is that dst.
    received = pickle.loads(self.store.get(f"send_tensor/{self.rank}/{seq}"))
    self.recv_src_counter[src] = seq + 1
    tensor.copy_(received)
    return tensor
||||
def all_reduce(
    self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
) -> torch.Tensor:
    """All-reduce a tensor across all ranks.

    Gathers every rank's tensor via the store-backed ``all_gather_obj`` and
    folds them together locally, so every rank computes the same result.

    Args:
        tensor: This rank's contribution.
        op: One of ``ReduceOp.SUM`` / ``PRODUCT`` / ``MAX`` / ``MIN``.

    Returns:
        A new tensor holding the reduction result; the input is not mutated
        (rank 0's gathered tensor is cloned before in-place folding).

    Raises:
        ValueError: If ``op`` is not one of the four supported reductions.
    """
    # Fail fast on unsupported ops: previously an unknown op fell through
    # every branch and silently returned rank 0's tensor unreduced, which
    # corrupts results downstream.
    if op not in (
        torch.distributed.ReduceOp.SUM,
        torch.distributed.ReduceOp.PRODUCT,
        torch.distributed.ReduceOp.MAX,
        torch.distributed.ReduceOp.MIN,
    ):
        raise ValueError(f"Unsupported reduce op: {op}")
    tensors = self.all_gather_obj(tensor)
    result = tensors[0].clone()
    for t in tensors[1:]:
        if op == torch.distributed.ReduceOp.SUM:
            result.add_(t)
        elif op == torch.distributed.ReduceOp.PRODUCT:
            result.mul_(t)
        elif op == torch.distributed.ReduceOp.MAX:
            result = torch.maximum(result, t)
        else:  # MIN — the only remaining validated op
            result = torch.minimum(result, t)
    return result
|
||||
def barrier(self, timeout: float = 30.0):
|
||||
"""A robust barrier to synchronize all ranks.
|
||||
|
||||
@@ -448,8 +497,14 @@ def init_gloo_process_group(
|
||||
|
||||
|
||||
def stateless_init_torch_distributed_process_group(
|
||||
host: str, port: int, rank: int, world_size: int, backend: str
|
||||
) -> ProcessGroup:
|
||||
host: str,
|
||||
port: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
group_name: str | None = None,
|
||||
return_store: bool = False,
|
||||
) -> ProcessGroup | tuple[ProcessGroup, Store]:
|
||||
"""
|
||||
A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state. The created ProcessGroup object can be used for
|
||||
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
try:
|
||||
|
||||
if backend == "gloo":
|
||||
pg = init_gloo_process_group(
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return current_platform.stateless_init_device_torch_dist_pg(
|
||||
pg = current_platform.stateless_init_device_torch_dist_pg(
|
||||
backend=backend,
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
except NotImplementedError:
|
||||
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
|
||||
# will raise a NotImplementedError. In this case, we fall back to gloo.
|
||||
return init_gloo_process_group(
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if group_name is not None:
|
||||
from torch._C._distributed_c10d import _register_process_group
|
||||
|
||||
pg._set_group_name(group_name)
|
||||
_register_process_group(group_name, pg)
|
||||
|
||||
if return_store:
|
||||
return pg, store
|
||||
else:
|
||||
return pg
|
||||
|
||||
|
||||
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
|
||||
|
||||
Reference in New Issue
Block a user