import contextlib
from typing import Optional

import torch
from torch.distributed import ProcessGroup, ReduceOp

from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    from vllm.distributed.device_communicators.pymccl import (MCCLCommunicator,
                                                              mcclGetVersion)
except Exception as e:
    # In non-MTHREADS environments we can't import the mccl module,
    # e.g. when running on machines with AMD GPUs.
    logger.info("Failed to import MCCL library: %s", e)
    logger.info("It is expected if you are not running on Mthreads GPUs.")

comm: Optional["MCCLCommunicator"] = None


def is_initialized() -> bool:
    """Returns whether the MCCL backend is initialized."""
    return comm is not None


@contextlib.contextmanager
def set_pymccl_stream(stream: torch.cuda.Stream):
    """Set the stream used for MCCL communication."""
    try:
        assert comm is not None
        comm.stream = stream
        yield
    finally:
        pass


def init_process_group(group: Optional[ProcessGroup] = None) -> None:
    assert not is_initialized()
    global comm
    logger.info("vLLM is using mccl==%s", mcclGetVersion())
    comm = MCCLCommunicator(group=group)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group."""
    assert input_.is_musa, f"{input_} should be a musa tensor"
    assert comm is not None
    comm.all_reduce(input_, op)


def destroy_process_group() -> None:
    global comm
    comm = None


def get_world_size() -> int:
    """Returns the world size."""
    assert comm is not None
    return comm.world_size


def get_nccl_backend() -> Optional["MCCLCommunicator"]:
    # Name retained from vLLM's pynccl_utils so existing call sites work.
    return comm
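

# Usage sketch (an added illustration, not part of the original module): it
# shows the intended call order, namely bootstrap torch.distributed, create
# the MCCL communicator, all-reduce a MUSA tensor under an explicit stream,
# then tear down. It assumes a MUSA build of PyTorch where "musa" is a valid
# device string and torch.cuda exposes the stream API, with two or more ranks
# launched via torchrun; adjust for your environment.
if __name__ == "__main__":
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")  # bootstrap a process group
    init_process_group()  # wrap it with an MCCL communicator
    x = torch.ones(16, device="musa")
    with set_pymccl_stream(torch.cuda.current_stream()):
        all_reduce(x)  # in-place sum across all ranks
    # Each element of x should now equal the number of participating ranks.
    assert int(x[0].item()) == get_world_size()
    destroy_process_group()
    dist.destroy_process_group()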