Files
enginex-vastai-va16-vllm/vllm_vacc/vllm/distributed/communication_op.py
2026-04-02 04:55:00 +00:00

24 lines
947 B
Python

# SPDX-License-Identifier: Apache-2.0
import torch
import torch.distributed
def tensor_model_parallel_all_reduce_with_odsp(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the tensor-parallel group.

    The payload size is computed as the byte size of ``input_`` multiplied by
    the TP group's world size. When that is below 4 MiB the call is handed to
    ``torch_vacc.vacc.all_reduce`` (the "DSP" path, per the failure message
    below). For larger payloads, or if anything on that path raises, the
    function falls back to the TP group's own ``all_reduce``.

    Args:
        input_: Tensor to reduce.

    Returns:
        Whatever the selected all-reduce implementation returns for
        ``input_``.
    """
    from vllm.distributed import get_tp_group

    try:
        tensor_bytes = input_.numel() * input_.element_size()
        tp_group = get_tp_group()
        payload_bytes = tensor_bytes * tp_group.world_size
        # only support 4M now
        if payload_bytes < 4 * 1024 * 1024:
            # Imported lazily, inside the try, so a failing import is treated
            # like any other failure of this path and triggers the fallback.
            from torch_vacc.vacc import all_reduce

            return all_reduce(
                input_,
                tp_group.rank_in_group,
                tp_group.world_size,
                tp_group.group_id,
                dev_info=tp_group.rank_device_infos,
            )
    except Exception as exc:
        # Best-effort: report the failure and fall through to the regular path.
        print("all_reduce by DSP run Fail, now use vccl-ops", exc, input_.shape, input_.dtype)

    # Reached when the payload is too large or the path above failed.
    return get_tp_group().all_reduce(input_)