24 lines
947 B
Python
24 lines
947 B
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import torch
|
|
import torch.distributed
|
|
|
|
# Upper bound (exclusive) on the payload handed to the DSP all-reduce,
# measured as the tensor's byte size multiplied by the TP world size.
# Payloads at or above this size use the generic group all-reduce instead.
_DSP_ALL_REDUCE_MAX_BYTES = 4 * 1024 * 1024  # 4 MiB


def tensor_model_parallel_all_reduce_with_odsp(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across the tensor-model-parallel group.

    When ``input_.numel() * input_.element_size() * world_size`` is below
    ``_DSP_ALL_REDUCE_MAX_BYTES``, the reduction is first attempted with
    ``torch_vacc.vacc.all_reduce`` (the DSP path). Larger payloads, or any
    exception raised while trying the DSP path, fall back to
    ``get_tp_group().all_reduce``.

    Args:
        input_: Tensor to reduce across the TP group.

    Returns:
        The reduced tensor, as returned by whichever backend handled it.
    """
    from vllm.distributed import get_tp_group

    # Fetch the group once rather than once per attribute access below.
    tp_group = get_tp_group()

    try:
        total_bytes = (input_.numel() * input_.element_size()
                       * tp_group.world_size)
        # The DSP path only supports payloads below 4 MiB for now.
        if total_bytes < _DSP_ALL_REDUCE_MAX_BYTES:
            from torch_vacc.vacc import all_reduce

            return all_reduce(input_,
                              tp_group.rank_in_group,
                              tp_group.world_size,
                              tp_group.group_id,
                              dev_info=tp_group.rank_device_infos)
    except Exception as e:
        # Deliberate best-effort: any exception from the DSP attempt
        # (including a failed import of torch_vacc) is reported and the
        # generic group all-reduce below is used instead.
        import logging
        logging.getLogger(__name__).warning(
            "all_reduce by DSP run Fail, now use vccl-ops: %s "
            "(shape=%s, dtype=%s)", e, input_.shape, input_.dtype)

    return tp_group.all_reduce(input_)
|
|
|