24 lines
733 B
Python
24 lines
733 B
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||
|
|
|
||
|
|
from typing import TYPE_CHECKING

import torch
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase,
)

if TYPE_CHECKING:
    # Only needed for the ``ca_comm`` annotation; never imported at runtime.
    from vllm.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )
|
||
|
|
|
||
|
|
|
||
|
|
class MLUCommunicator(DeviceCommunicatorBase):
    """Device communicator for MLU devices (accessed through ``torch.mlu``).

    A thin subclass of ``DeviceCommunicatorBase``. All setup is delegated to
    the base class, after which ``self.device`` is replaced with the MLU that
    is current for this process. No custom all-reduce object is constructed,
    so ``ca_comm`` is always ``None``.
    """

    def __init__(
        self,
        cpu_group: ProcessGroup,
        device: torch.device | None = None,
        device_group: ProcessGroup | None = None,
        unique_name: str = "",
    ) -> None:
        """Initialize the communicator.

        Args:
            cpu_group: Process group forwarded unchanged to the base class.
            device: Forwarded to the base class. The resulting
                ``self.device`` is then overwritten with the current MLU.
            device_group: Process group forwarded unchanged to the base class.
            unique_name: Identifier forwarded unchanged to the base class.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        # Bind to whichever MLU is current for this process, ignoring the
        # ``device`` argument handled by the base class.
        # NOTE(review): ``torch.mlu.current_device()`` presumably returns an
        # int index (as ``torch.cuda.current_device()`` does), not a
        # ``torch.device``. Confirm that every user of ``self.device``
        # accepts that type.
        self.device = torch.mlu.current_device()
        # No custom all-reduce implementation is created on this path, so the
        # attribute stays ``None``. It is presumably read by shared code that
        # mirrors the CUDA communicator (TODO confirm).
        self.ca_comm: CustomAllreduce | None = None
|