init
This commit is contained in:
23
vllm_vacc/vllm/distributed/communication_op.py
Normal file
23
vllm_vacc/vllm/distributed/communication_op.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
def tensor_model_parallel_all_reduce_with_odsp(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the tensor-model-parallel group.

    Tries the on-device DSP all-reduce from ``torch_vacc`` first; if the
    payload exceeds the DSP size limit or the DSP path raises, falls back
    to the TP group's regular (vccl) all-reduce.

    Args:
        input_: Tensor to reduce across the tensor-parallel group.

    Returns:
        The all-reduced tensor (from the DSP path when eligible, otherwise
        from ``get_tp_group().all_reduce``).
    """
    # Deferred import: avoids a circular import at module load time.
    from vllm.distributed import get_tp_group

    # Look the group up once instead of five separate get_tp_group() calls.
    tp_group = get_tp_group()

    # DSP path "only support 4M now" (original comment): 4 MiB, measured as
    # the group-wide total (per-rank bytes * world_size per the formula below).
    dsp_max_total_bytes = 4 * 1024 * 1024  # == 4194304

    try:
        total_bytes = input_.numel() * input_.element_size() * tp_group.world_size
        if total_bytes < dsp_max_total_bytes:
            from torch_vacc.vacc import all_reduce
            return all_reduce(input_,
                              tp_group.rank_in_group,
                              tp_group.world_size,
                              tp_group.group_id,
                              dev_info=tp_group.rank_device_infos)
    except Exception as e:
        # Deliberate best-effort: any DSP failure (import error, unsupported
        # shape/dtype, runtime fault) falls through to the vccl all-reduce.
        print("all_reduce by DSP run Fail, now use vccl-ops", e, input_.shape, input_.dtype)

    return tp_group.all_reduce(input_)
|
||||
|
||||
Reference in New Issue
Block a user