Files
xc-llm-kunlun/vllm_kunlun/distributed/kunlun_communicator.py
2025-12-10 17:51:24 +08:00

85 lines
2.7 KiB
Python

"""kunlun_communicator"""
from contextlib import contextmanager
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
class KunlunCommunicator(CudaCommunicator):
"""KunlunCommunicator"""
def __init__(self,
device,
device_group,
cpu_group,
unique_name):
"""
Initializes the CUDA Communicator.
Args:
cpu_group (ProcessGroup): The CPU process group.
device (Optional[torch.device], optional): The device to use. Defaults to None.
device_group (Optional[ProcessGroup], optional): The device process group. Defaults to None.
unique_name (str, optional): The unique name of this communicator. Defaults to "".
Raises:
ValueError: If both ``device`` and ``device_group`` are not specified.
"""
DeviceCommunicatorBase.__init__(self, cpu_group, device, device_group, unique_name)
self.ca_comm = None
self.disabled = False
with torch.cuda.device(device):
self.stream = torch.cuda.Stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
self.stream.synchronize()
del data
def all_reduce(self, input_):
"""all_reduce"""
return DeviceCommunicatorBase.all_reduce(self, input_)
def all_gather(self, input_, dim):
"""all_gather"""
return DeviceCommunicatorBase.all_gather(self, input_, dim)
def gather(self, input_, dst, dim):
"""gather"""
return DeviceCommunicatorBase.gather(self, input_, dst, dim)
def send(self, tensor, dst):
"""send"""
DeviceCommunicatorBase.send(self, tensor, dst)
def recv(self, size, dtype, src):
"""recv"""
return DeviceCommunicatorBase.recv(self, size, dtype, src)
def destroy(self):
"""destroy"""
pass
@contextmanager
def change_state(self, enable, stream):
"""
A context manager to change the state of the communicator.
"""
if enable is None:
# guess a default value when not specified
enable = self.available
if stream is None:
stream = self.stream
old_disable = self.disabled
old_stream = self.stream
self.stream = stream
self.disabled = not enable
yield
self.disabled = old_disable
self.stream = old_stream