sgl-kernel: switch custom allreduce from the TRT kernel to the vLLM kernel (#5079)
This commit is contained in:
@@ -16,7 +16,6 @@ from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibra
|
||||
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
ranks = list(range(world_size))
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
@@ -26,39 +25,18 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
|
||||
)
|
||||
group = dist.group.WORLD
|
||||
|
||||
buffer_max_size = 8 * 1024 * 1024
|
||||
barrier_max_size = 8 * (24 + 2) * 8
|
||||
buffer_ptrs = None
|
||||
tmp_result_buffer_ptrs = None
|
||||
barrier_in_ptrs = None
|
||||
barrier_out_ptrs = None
|
||||
custom_ptr = None
|
||||
|
||||
try:
|
||||
buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
buffer_max_size, group=group
|
||||
)
|
||||
tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
buffer_max_size, group=group
|
||||
)
|
||||
barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
barrier_max_size, group=group
|
||||
)
|
||||
barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
barrier_max_size, group=group
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
max_size = 8192 * 1024
|
||||
meta_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
custom_ops.meta_size() + max_size, group=group
|
||||
)
|
||||
|
||||
rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)
|
||||
|
||||
custom_ptr = custom_ops.init_custom_reduce(
|
||||
rank,
|
||||
world_size,
|
||||
rank_data,
|
||||
buffer_ptrs,
|
||||
tmp_result_buffer_ptrs,
|
||||
barrier_in_ptrs,
|
||||
barrier_out_ptrs,
|
||||
)
|
||||
custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
|
||||
custom_ops.register_buffer(custom_ptr, buffer_ptrs)
|
||||
|
||||
test_loop = 10
|
||||
for sz in test_sizes:
|
||||
@@ -68,7 +46,9 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
|
||||
inp1_ref = inp1.clone()
|
||||
out1 = torch.empty_like(inp1)
|
||||
|
||||
custom_ops.custom_reduce(custom_ptr, inp1, out1)
|
||||
custom_ops.all_reduce(
|
||||
custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
|
||||
)
|
||||
|
||||
dist.all_reduce(inp1_ref, group=group)
|
||||
|
||||
@@ -77,15 +57,11 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
|
||||
finally:
|
||||
dist.barrier(group=group)
|
||||
if custom_ptr is not None:
|
||||
custom_ops.custom_dispose(custom_ptr)
|
||||
custom_ops.dispose(custom_ptr)
|
||||
if buffer_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
|
||||
if tmp_result_buffer_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
|
||||
if barrier_in_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
|
||||
if barrier_out_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
|
||||
if meta_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)
|
||||
|
||||
dist.destroy_process_group(group=group)
|
||||
|
||||
@@ -122,7 +98,18 @@ def multi_process_parallel(
|
||||
|
||||
|
||||
class TestCustomAllReduce(unittest.TestCase):
|
||||
test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
|
||||
test_sizes = [
|
||||
512,
|
||||
2560,
|
||||
4096,
|
||||
5120,
|
||||
7680,
|
||||
32768,
|
||||
262144,
|
||||
524288,
|
||||
1048576,
|
||||
2097152,
|
||||
]
|
||||
world_sizes = [2, 4, 8]
|
||||
|
||||
@staticmethod
|
||||
Reference in New Issue
Block a user