Files

24 lines
733 B
Python
Raw Permalink Normal View History

2026-04-24 09:50:34 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import TYPE_CHECKING

import torch
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase,
)

if TYPE_CHECKING:
    # Type-only import: avoids loading the CUDA custom all-reduce module at
    # runtime on MLU while still giving type checkers a definition.
    from vllm.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )
class MLUCommunicator(DeviceCommunicatorBase):
    """Device communicator for MLU devices (driven through ``torch.mlu``).

    Thin subclass of vLLM's ``DeviceCommunicatorBase``. All constructor
    arguments are forwarded unchanged to the base initializer. Afterwards
    ``self.device`` is re-bound to the process's current MLU device.
    """

    def __init__(
        self,
        cpu_group: ProcessGroup,
        device: torch.device | None = None,
        device_group: ProcessGroup | None = None,
        unique_name: str = "",
    ):
        """Initialize the communicator.

        Args:
            cpu_group: Process group forwarded to ``DeviceCommunicatorBase``.
            device: Forwarded to the base initializer. Note that
                ``self.device`` is overwritten right afterwards (see below).
            device_group: Optional device-side process group, forwarded to
                the base initializer.
            unique_name: Identifier forwarded to the base initializer.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        # Bind to whichever MLU is current for this process rather than the
        # ``device`` argument; this replaces any value the base initializer
        # may have stored in ``self.device``.
        # NOTE(review): ``torch.mlu.current_device()`` presumably returns an
        # integer index (like ``torch.cuda.current_device()``), not a
        # ``torch.device`` -- confirm downstream users accept that.
        self.device = torch.mlu.current_device()
        # Custom all-reduce handle. It starts as ``None`` and is not created in
        # this constructor. The annotation is quoted so ``CustomAllreduce`` is
        # needed only by type checkers, never imported at runtime.
        self.ca_comm: "CustomAllreduce | None" = None