[CPU] fix all_reduce and all_gather (#6770)
Co-authored-by: blzheng <beilei.zheng@intel.com>
This commit is contained in:
@@ -42,8 +42,10 @@ from torch.distributed import Backend, ProcessGroup
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
get_int_env_var,
|
||||||
is_cuda_alike,
|
is_cuda_alike,
|
||||||
is_npu,
|
is_npu,
|
||||||
|
is_shm_available,
|
||||||
supports_custom_op,
|
supports_custom_op,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -222,6 +224,7 @@ class GroupCoordinator:
|
|||||||
self.local_rank = local_rank
|
self.local_rank = local_rank
|
||||||
self.device_group = None
|
self.device_group = None
|
||||||
self.cpu_group = None
|
self.cpu_group = None
|
||||||
|
self.local_size = get_int_env_var("LOCAL_SIZE", 0)
|
||||||
|
|
||||||
for ranks in group_ranks:
|
for ranks in group_ranks:
|
||||||
device_group = torch.distributed.new_group(
|
device_group = torch.distributed.new_group(
|
||||||
@@ -440,9 +443,12 @@ class GroupCoordinator:
|
|||||||
return input_
|
return input_
|
||||||
|
|
||||||
if input_.is_cpu:
|
if input_.is_cpu:
|
||||||
import intel_extension_for_pytorch as ipex
|
if is_shm_available(input_.dtype, self.world_size, self.local_size):
|
||||||
|
torch.ops.sgl_kernel.shm_allreduce(
|
||||||
ipex.distributed.all_reduce(input_, group=self.device_group)
|
input_, torch.distributed.ReduceOp.SUM
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||||
return input_
|
return input_
|
||||||
|
|
||||||
if not supports_custom_op():
|
if not supports_custom_op():
|
||||||
@@ -570,6 +576,16 @@ class GroupCoordinator:
|
|||||||
output_tensor = torch.empty(
|
output_tensor = torch.empty(
|
||||||
output_size, dtype=input_.dtype, device=input_.device
|
output_size, dtype=input_.dtype, device=input_.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if input_.is_cpu:
|
||||||
|
if is_shm_available(input_.dtype, self.world_size, self.local_size):
|
||||||
|
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
|
||||||
|
else:
|
||||||
|
torch.distributed.all_gather_into_tensor(
|
||||||
|
output_tensor, input_, group=self.device_group
|
||||||
|
)
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
# All-gather.
|
# All-gather.
|
||||||
self.all_gather_into_tensor(output_tensor, input_)
|
self.all_gather_into_tensor(output_tensor, input_)
|
||||||
# Reshape
|
# Reshape
|
||||||
|
|||||||
@@ -506,9 +506,13 @@ class ModelRunner:
|
|||||||
if _is_cpu_amx_available:
|
if _is_cpu_amx_available:
|
||||||
# Bind OpenMP threads to CPU cores
|
# Bind OpenMP threads to CPU cores
|
||||||
torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
|
torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
|
||||||
|
|
||||||
|
# Set local size to hint SGLang to use shared memory based AllReduce
|
||||||
|
os.environ["LOCAL_SIZE"] = str(self.tp_size)
|
||||||
|
torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"init_cpu_threads_env is skipped since intel amx backend is not available"
|
"init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only initialize the distributed environment on the target model worker.
|
# Only initialize the distributed environment on the target model worker.
|
||||||
|
|||||||
@@ -2612,3 +2612,12 @@ def get_cpu_ids_by_node():
|
|||||||
|
|
||||||
# ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
|
# ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
|
||||||
return cpu_ids
|
return cpu_ids
|
||||||
|
|
||||||
|
|
||||||
|
def is_shm_available(dtype, world_size, local_size):
    """Report whether the shared-memory collective path may be used.

    The checks run in this order and stop at the first one that fails:

    1. ``cpu_has_amx_support()`` must be truthy.
    2. ``dtype`` must be ``torch.bfloat16`` or ``torch.float``.
    3. ``world_size`` must be at least 1.
    4. ``world_size`` must equal ``local_size``.

    Args:
        dtype: Data type of the tensor taking part in the collective.
        world_size: Number of ranks in the communication group.
        local_size: Value compared against ``world_size``; presumably the
            number of ranks on the local host -- confirm against the caller.

    Returns:
        A truthy value when every check passes, otherwise a falsy one.
    """
    amx_supported = cpu_has_amx_support()
    if not amx_supported:
        # Hand back the probe's own falsy result, exactly as a short-circuit
        # ``and`` chain would.
        return amx_supported

    if dtype not in (torch.bfloat16, torch.float):
        return False

    if not world_size >= 1:
        return False

    return world_size == local_size
|
||||||
|
|||||||
Reference in New Issue
Block a user