# SPDX-License-Identifier: Apache-2.0
import torch
import torch.distributed


def tensor_model_parallel_all_reduce_with_odsp(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across the tensor model parallel group.

    Uses the on-device DSP all_reduce from torch_vacc for small payloads
    and falls back to the TP group's vccl all_reduce otherwise.
    """
    from vllm.distributed import get_tp_group

    tp_group = get_tp_group()
    try:
        # Total payload across the TP group; the DSP path only supports
        # payloads under 4 MiB (4194304 bytes) for now.
        total_bytes = (input_.numel() * input_.element_size() *
                       tp_group.world_size)
        if total_bytes < 4194304:
            from torch_vacc.vacc import all_reduce
            return all_reduce(input_,
                              tp_group.rank_in_group,
                              tp_group.world_size,
                              tp_group.group_id,
                              dev_info=tp_group.rank_device_infos)
    except Exception as e:
        print("DSP all_reduce failed, falling back to vccl ops:", e,
              input_.shape, input_.dtype)
    return tp_group.all_reduce(input_)
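

# --- Illustrative sketch (not part of the module logic above) ---------------
# A minimal, standalone demonstration of the dispatch rule used in
# tensor_model_parallel_all_reduce_with_odsp: the DSP path is taken only when
# the total payload (tensor bytes times TP world size) stays under 4 MiB.
# The world_size value here is a hypothetical example and is not read from
# vLLM; no distributed setup or VACC device is required to run this block.
if __name__ == "__main__":
    def would_use_dsp_path(t: torch.Tensor, world_size: int) -> bool:
        # Mirrors the size check above: payload across the whole TP group.
        total_bytes = t.numel() * t.element_size() * world_size
        return total_bytes < 4194304

    # Example: a [1024, 128] fp16 tensor is 256 KiB per rank; with a
    # hypothetical TP world size of 8 the total payload is 2 MiB, so the
    # DSP path would be chosen.
    example = torch.zeros(1024, 128, dtype=torch.float16)
    print("DSP path:", would_use_dsp_path(example, world_size=8))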