Turn on shm_allreduce and shm_allgather for fp16 (#10725)
This commit is contained in:
@@ -2921,7 +2921,7 @@ def get_cpu_ids_by_node():
|
||||
def is_shm_available(dtype, world_size, local_size):
|
||||
return (
|
||||
cpu_has_amx_support()
|
||||
and dtype in [torch.bfloat16, torch.float]
|
||||
and dtype in [torch.bfloat16, torch.float16, torch.float]
|
||||
and world_size >= 1
|
||||
and world_size == local_size
|
||||
)
|
||||
|
||||
118
test/srt/cpu/test_comm.py
Normal file
118
test/srt/cpu/test_comm.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import copy
|
||||
import multiprocessing
|
||||
import os
|
||||
import traceback
|
||||
import unittest
|
||||
from multiprocessing import Process
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from utils import precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase, find_available_port
|
||||
|
||||
|
||||
def run_distributed_test(rank, world_size, master_port, output_writer, fn):
|
||||
try:
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["LOCAL_SIZE"] = str(world_size)
|
||||
|
||||
dist.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
torch.ops.sgl_kernel.initialize(world_size, rank)
|
||||
|
||||
fn(rank, world_size)
|
||||
|
||||
execution_ok = True
|
||||
except Exception as e:
|
||||
print(f"subprocess[{rank=}] has error: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
execution_ok = False
|
||||
|
||||
output_writer.send(execution_ok)
|
||||
output_writer.close()
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def all_reduce_fn(rank, world_size):
|
||||
op = dist.ReduceOp.SUM
|
||||
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
|
||||
tensor = torch.randn(2, 10, dtype=dtype)
|
||||
tensor_shm = copy.deepcopy(tensor)
|
||||
|
||||
dist.all_reduce(tensor, op=op)
|
||||
torch.ops.sgl_kernel.shm_allreduce(tensor_shm, op)
|
||||
|
||||
torch.testing.assert_close(tensor, tensor_shm)
|
||||
|
||||
|
||||
def all_gather_fn(rank, world_size):
|
||||
dim = -1
|
||||
|
||||
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
|
||||
tensor = torch.randn(2, 10, dtype=dtype)
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += tensor.dim()
|
||||
|
||||
input_size = tensor.size()
|
||||
output_size = (input_size[0] * world_size,) + input_size[1:]
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
dist.all_gather_into_tensor(output_tensor, tensor)
|
||||
output_tensor = output_tensor.reshape((world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
|
||||
)
|
||||
|
||||
output_shm = torch.ops.sgl_kernel.shm_allgather(tensor, dim)
|
||||
|
||||
torch.testing.assert_close(output_tensor, output_shm)
|
||||
|
||||
|
||||
class TestComm(CustomTestCase):
|
||||
def _spawn_and_check(self, fn, world_size=2):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
master_port = find_available_port(23456)
|
||||
|
||||
processes = []
|
||||
output_reader, output_writer = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
for rank in range(world_size):
|
||||
p = Process(
|
||||
target=run_distributed_test,
|
||||
kwargs=dict(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
master_port=master_port,
|
||||
output_writer=output_writer,
|
||||
fn=fn,
|
||||
),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for _ in range(world_size):
|
||||
self.assertTrue(output_reader.recv(), "Subprocess fail. Check logs above.")
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
def test_all_reduce(self):
|
||||
self._spawn_and_check(all_reduce_fn)
|
||||
|
||||
def test_all_gather(self):
|
||||
self._spawn_and_check(all_gather_fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user