optimize custom allreduce kernel (#2904)
@@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
 import ray
 import torch
 import torch.distributed as dist
+from sgl_kernel import ops as custom_ops
 from torch.distributed import ProcessGroup
 from vllm import _custom_ops as vllm_ops
 
@@ -104,35 +105,38 @@ class TestCustomAllReduce(unittest.TestCase):
         multi_process_parallel(world_size, self, self.performance)
 
     def init_custom_allreduce(self, rank, world_size, group):
-        import sgl_kernel
-
         buffer_max_size = 8 * 1024 * 1024
         barrier_max_size = 8 * (24 + 2) * 8
 
         self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
         self.tmp_result_buffer_ptrs = self.create_shared_buffer(
             buffer_max_size, group=group
         )
         self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
         self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
         self.rank_data = torch.empty(
             8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
         )
 
-        self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
+        self.custom_ptr = custom_ops.init_custom_reduce(
             rank,
             world_size,
             self.rank_data,
             self.buffer_ptrs,
             self.tmp_result_buffer_ptrs,
             self.barrier_in_ptrs,
             self.barrier_out_ptrs,
         )
 
     def custom_allreduce(self, inp, out):
-        import sgl_kernel
-
-        sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)
+        custom_ops.custom_reduce(self.custom_ptr, inp, out)
 
     def free_custom_allreduce(self, group):
-        import sgl_kernel
-
         self.free_shared_buffer(self.buffer_ptrs, group)
         self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
         self.free_shared_buffer(self.barrier_in_ptrs, group)
         self.free_shared_buffer(self.barrier_out_ptrs, group)
-        sgl_kernel.ops.custom_dispose(self.custom_ptr)
+        custom_ops.custom_dispose(self.custom_ptr)
 
     def init_vllm_allreduce(self, rank, group):
         self.vllm_rank = rank
||||