sgl-kernel transfer custom allreduce from trt kernel to vllm kernel (#5079)

Yi Zhang
2025-04-06 05:23:20 +08:00
committed by GitHub
parent 0d99adb715
commit bcbbf519f9
10 changed files with 692 additions and 937 deletions
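
This migrates the custom allreduce bindings from the TRT-style API (init_custom_reduce / custom_reduce / custom_dispose, which required separate data, tmp-result, and barrier-in/barrier-out shared buffers) to the vLLM-style API (init_custom_ar / register_buffer / all_reduce / dispose, which uses a single shared region sized by meta_size() plus one registered data buffer). The correctness test below is updated to the new call sequence, and test_sizes gains the additional message sizes 2560, 5120, and 7680.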


@@ -16,7 +16,6 @@ from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    ranks = list(range(world_size))
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     dist.init_process_group(
         backend="nccl",
@@ -26,39 +25,18 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
     )
     group = dist.group.WORLD
-    buffer_max_size = 8 * 1024 * 1024
-    barrier_max_size = 8 * (24 + 2) * 8
-    buffer_ptrs = None
-    tmp_result_buffer_ptrs = None
-    barrier_in_ptrs = None
-    barrier_out_ptrs = None
-    custom_ptr = None
     try:
-        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
-            buffer_max_size, group=group
-        )
-        tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
-            buffer_max_size, group=group
-        )
-        barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
-            barrier_max_size, group=group
-        )
-        barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
-            barrier_max_size, group=group
+        device = torch.device(f"cuda:{rank}")
+        max_size = 8192 * 1024
+        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
+            custom_ops.meta_size() + max_size, group=group
         )
         rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
+        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)
-        custom_ptr = custom_ops.init_custom_reduce(
-            rank,
-            world_size,
-            rank_data,
-            buffer_ptrs,
-            tmp_result_buffer_ptrs,
-            barrier_in_ptrs,
-            barrier_out_ptrs,
-        )
+        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
+        custom_ops.register_buffer(custom_ptr, buffer_ptrs)
         test_loop = 10
         for sz in test_sizes:
@@ -68,7 +46,9 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
                     inp1_ref = inp1.clone()
                     out1 = torch.empty_like(inp1)
-                    custom_ops.custom_reduce(custom_ptr, inp1, out1)
+                    custom_ops.all_reduce(
+                        custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
+                    )
                     dist.all_reduce(inp1_ref, group=group)
@@ -77,15 +57,11 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
finally:
dist.barrier(group=group)
if custom_ptr is not None:
custom_ops.custom_dispose(custom_ptr)
custom_ops.dispose(custom_ptr)
if buffer_ptrs:
TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
if tmp_result_buffer_ptrs:
TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
if barrier_in_ptrs:
TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
if barrier_out_ptrs:
TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
if meta_ptrs:
TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)
dist.destroy_process_group(group=group)
@@ -122,7 +98,18 @@ def multi_process_parallel(
 class TestCustomAllReduce(unittest.TestCase):
-    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+    test_sizes = [
+        512,
+        2560,
+        4096,
+        5120,
+        7680,
+        32768,
+        262144,
+        524288,
+        1048576,
+        2097152,
+    ]
     world_sizes = [2, 4, 8]
 
     @staticmethod
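
For reference, the lifecycle the updated test exercises is: allocate one IPC-shared region per rank for the communicator's metadata plus staging space (custom_ops.meta_size() + max_size), allocate a data buffer, create the communicator with init_custom_ar, register the buffer, run all_reduce against the per-rank buffer pointer, then dispose and free the shared memory. Below is a minimal sketch of that flow, not a definitive implementation: the custom_ops import path is an assumption (the test's actual imports are outside the hunks shown), the meaning of the final init_custom_ar argument as a "fully NVLink-connected" flag is an assumption, and create_shared_buffer / free_shared_buffer are the helper staticmethods defined in this test file.

import torch
import torch.distributed as dist

# Assumed import paths; the real imports are not visible in this diff.
import sgl_kernel.allreduce as custom_ops
from test_custom_allreduce import TestCustomAllReduce


def custom_allreduce_once(rank: int, world_size: int, group) -> torch.Tensor:
    """One vLLM-style custom allreduce; assumes dist is already initialized."""
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    max_size = 8192 * 1024  # staging-buffer capacity in bytes, as in the test
    # Shared region per rank: communicator metadata (meta_size()) + staging space.
    meta_ptrs = TestCustomAllReduce.create_shared_buffer(
        custom_ops.meta_size() + max_size, group=group
    )
    buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)
    # Rank-local scratch tensor handed to the communicator, sized as in the test.
    rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)

    # Final argument assumed to mean "all GPUs are fully NVLink-connected".
    custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
    custom_ops.register_buffer(custom_ptr, buffer_ptrs)
    try:
        inp = torch.ones(4096, dtype=torch.float32, device=device)
        out = torch.empty_like(inp)
        # Reduce through this rank's registered staging buffer.
        custom_ops.all_reduce(custom_ptr, inp, out, buffer_ptrs[rank], max_size)
        return out  # expected: world_size * inp, as the test checks via dist.all_reduce
    finally:
        custom_ops.dispose(custom_ptr)
        TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
        TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)

Relative to the TRT-style API, the separate tmp-result and barrier-in/barrier-out allocations disappear; their role is presumably absorbed by the meta region, which is why teardown now frees only buffer_ptrs and meta_ptrs.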