optimize custom allreduce kernel (#2904)

commit 6cb3974e77
Author: yizhang2077
Date:   2025-01-16 03:04:25 +08:00
Committed by: GitHub
Parent: f65c13b559
9 changed files with 244 additions and 80 deletions
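
For orientation before the hunks: the change moves the `sgl_kernel` import to module scope and threads three new shared regions (a temporary result buffer plus split in/out barriers) through `init_custom_reduce`. Below is a condensed sketch of the resulting lifecycle, assembled from the test diff itself. The `create_shared_buffer` / `free_shared_buffer` callables are stand-ins for the test's own helper methods (which allocate IPC-shared CUDA memory and exchange handles across the process group); they are assumptions here, not sgl_kernel API.

```python
# Sketch of the updated custom-allreduce lifecycle, mirroring the test below.
import torch
from sgl_kernel import ops as custom_ops


def run_custom_allreduce(rank, world_size, group, inp, out,
                         create_shared_buffer, free_shared_buffer):
    buffer_max_size = 8 * 1024 * 1024    # 8 MiB payload buffer per rank
    barrier_max_size = 8 * (24 + 2) * 8  # barrier flag region, as in the test

    # One shared region per role: payload, temporary results, and the
    # in/out barrier flags introduced by this commit.
    buffer_ptrs = create_shared_buffer(buffer_max_size, group=group)
    tmp_result_buffer_ptrs = create_shared_buffer(buffer_max_size, group=group)
    barrier_in_ptrs = create_shared_buffer(barrier_max_size, group=group)
    barrier_out_ptrs = create_shared_buffer(barrier_max_size, group=group)
    rank_data = torch.empty(
        8 * 1024 * 1024, dtype=torch.uint8, device=f"cuda:{rank}"
    )

    # Extended signature: the three new pointer lists follow buffer_ptrs.
    ptr = custom_ops.init_custom_reduce(
        rank, world_size, rank_data,
        buffer_ptrs, tmp_result_buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs,
    )
    custom_ops.custom_reduce(ptr, inp, out)  # out <- allreduce(inp) over ranks
    custom_ops.custom_dispose(ptr)           # release kernel-side state

    for ptrs in (buffer_ptrs, tmp_result_buffer_ptrs,
                 barrier_in_ptrs, barrier_out_ptrs):
        free_shared_buffer(ptrs, group)
```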


@@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
 import ray
 import torch
 import torch.distributed as dist
+from sgl_kernel import ops as custom_ops
 from torch.distributed import ProcessGroup
 from vllm import _custom_ops as vllm_ops
@@ -104,35 +105,38 @@ class TestCustomAllReduce(unittest.TestCase):
         multi_process_parallel(world_size, self, self.performance)

     def init_custom_allreduce(self, rank, world_size, group):
-        import sgl_kernel
-
         buffer_max_size = 8 * 1024 * 1024
         barrier_max_size = 8 * (24 + 2) * 8

         self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
+        self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+            buffer_max_size, group=group
+        )
+        self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
+        self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
         self.rank_data = torch.empty(
             8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
         )
-        self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
+        self.custom_ptr = custom_ops.init_custom_reduce(
             rank,
             world_size,
             self.rank_data,
             self.buffer_ptrs,
+            self.tmp_result_buffer_ptrs,
+            self.barrier_in_ptrs,
+            self.barrier_out_ptrs,
         )

     def custom_allreduce(self, inp, out):
-        import sgl_kernel
-
-        sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)
+        custom_ops.custom_reduce(self.custom_ptr, inp, out)

     def free_custom_allreduce(self, group):
-        import sgl_kernel
-
         self.free_shared_buffer(self.buffer_ptrs, group)
+        self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
+        self.free_shared_buffer(self.barrier_in_ptrs, group)
+        self.free_shared_buffer(self.barrier_out_ptrs, group)
-        sgl_kernel.ops.custom_dispose(self.custom_ptr)
+        custom_ops.custom_dispose(self.custom_ptr)

     def init_vllm_allreduce(self, rank, group):
         self.vllm_rank = rank
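
Two readings of the new API shape, offered as inferences rather than facts stated in the diff: splitting synchronization state into separate `barrier_in` and `barrier_out` regions is the usual two-phase barrier pattern, which keeps a rank that has finished one all-reduce from clobbering flags another rank is still reading; and `barrier_max_size = 8 * (24 + 2) * 8 = 1664` bytes would be consistent with one 8-byte flag per (rank, block) pair for up to 8 ranks and 24 blocks, plus 2 spare slots.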