Files
xc-llm-ascend/vllm_ascend/distributed/device_communicators/npu_communicator.py
SILONG ZENG 7faa6878a6 [Lint]Style: Convert vllm-ascend/ to ruff format (Batch #3) (#5978)
### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren <user@SorendeMac-mini.local>
2026-01-24 22:10:18 +08:00

65 lines
2.6 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
class NPUCommunicator(DeviceCommunicatorBase):
    """Device communicator for Ascend NPUs.

    Extends vLLM's ``DeviceCommunicatorBase`` with an ``all_to_all``
    collective built on ``torch.distributed.all_to_all`` over
    ``self.device_group``.
    """

    def __init__(
        self,
        cpu_group: dist.ProcessGroup,
        device: torch.device | None = None,
        device_group: dist.ProcessGroup | None = None,
        unique_name: str = "",
    ):
        """Initialize the communicator.

        Args:
            cpu_group: Process group forwarded to the base class.
            device: Forwarded to the base class. ``self.device`` is
                reassigned below regardless of this argument.
            device_group: Process group forwarded to the base class;
                ``all_to_all`` communicates over ``self.device_group``.
            unique_name: Name forwarded to the base class.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
        # init device according to rank
        # NOTE(review): torch.npu.current_device() presumably returns a device
        # index (like torch.cuda.current_device), not a torch.device — confirm
        # that consumers of self.device accept that.
        self.device = torch.npu.current_device()

    def all_to_all(
        self,
        input_: torch.Tensor,
        scatter_dim: int = 0,
        gather_dim: int = -1,
        scatter_sizes: list[int] | None = None,
        gather_sizes: list[int] | None = None,
    ) -> torch.Tensor:
        """Scatter ``input_`` along one dimension and gather peers' chunks along another.

        ``input_`` is split into ``world_size`` chunks along ``scatter_dim``.
        Chunk ``i`` is sent to rank ``i``. The chunks received from all ranks
        are concatenated along ``gather_dim`` in rank order.

        Args:
            input_: Local tensor to exchange.
            scatter_dim: Dimension to split; negative values count from the end.
            gather_dim: Dimension along which received chunks are concatenated;
                negative values count from the end.
            scatter_sizes: Per-destination chunk sizes along ``scatter_dim``.
            gather_sizes: Per-source sizes along ``gather_dim`` of the chunks
                this rank receives.

        Uneven splits are used only when BOTH ``scatter_sizes`` and
        ``gather_sizes`` are given. If either is ``None``, the input is split
        into ``world_size`` near-equal chunks with ``torch.tensor_split``, and
        every rank is assumed to hold an input of the same shape.

        Returns:
            A contiguous tensor holding the gathered chunks.
        """
        if scatter_dim < 0:
            scatter_dim += input_.dim()
        if gather_dim < 0:
            gather_dim += input_.dim()

        if scatter_sizes is not None and gather_sizes is not None:
            input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]
            # Received chunks share this rank's own chunk shape except along
            # gather_dim, whose per-source size comes from gather_sizes.
            # NOTE(review): along scatter_dim this assumes peers send a chunk of
            # size scatter_sizes[self.rank], i.e. scatter_sizes agree across ranks.
            tensor_shape_base = input_list[self.rank].size()
            output_list = []
            for i in range(self.world_size):
                tensor_shape = list(tensor_shape_base)
                tensor_shape[gather_dim] = gather_sizes[i]
                output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
        else:
            input_list = [t.contiguous() for t in torch.tensor_split(input_, self.world_size, scatter_dim)]
            # Every peer sends us ITS chunk number self.rank. With equal input
            # shapes across ranks, that chunk has the shape of our own chunk
            # self.rank. Allocating like input_list[i] is only correct when the
            # split is even; tensor_split can yield unequal chunks.
            recv_template = input_list[self.rank]
            output_list = [torch.empty_like(recv_template) for _ in range(self.world_size)]

        dist.all_to_all(output_list, input_list, group=self.device_group)
        output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
        return output_tensor