python transfer custom allreduce from trt kernel to vllm kernel (#5080)
This commit is contained in:
@@ -47,7 +47,7 @@ runtime_common = [
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.0.7",
|
||||
"sgl-kernel==0.0.8",
|
||||
"flashinfer_python==0.2.3",
|
||||
"torch==2.5.1",
|
||||
"cuda-python",
|
||||
|
||||
@@ -27,17 +27,20 @@ if not is_hpu():
|
||||
logger.warning("Failed to import from custom_ar with %r", e)
|
||||
|
||||
|
||||
if use_vllm_custom_allreduce and not is_hip():
|
||||
# vLLM custom allreduce
|
||||
if not is_hip():
|
||||
if use_vllm_custom_allreduce:
|
||||
custom_op = torch.ops._C_custom_ar
|
||||
else:
|
||||
custom_op = sgl_kernel.allreduce
|
||||
|
||||
# custom allreduce
|
||||
def init_custom_ar(
|
||||
ipc_tensors: List[torch.Tensor],
|
||||
rank_data: torch.Tensor,
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return torch.ops._C_custom_ar.init_custom_ar(
|
||||
ipc_tensors, rank_data, rank, full_nvlink
|
||||
)
|
||||
return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
|
||||
|
||||
def all_reduce(
|
||||
fa: int,
|
||||
@@ -46,105 +49,69 @@ if use_vllm_custom_allreduce and not is_hip():
|
||||
reg_buffer: int,
|
||||
reg_buffer_sz_bytes: int,
|
||||
) -> None:
|
||||
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
|
||||
custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops._C_custom_ar.dispose(fa)
|
||||
custom_op.dispose(fa)
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops._C_custom_ar.meta_size()
|
||||
return custom_op.meta_size()
|
||||
|
||||
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
|
||||
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
|
||||
return custom_op.register_buffer(fa, ipc_tensors)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
||||
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
|
||||
return custom_op.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||
) -> None:
|
||||
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
|
||||
custom_op.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
else:
|
||||
if is_hip():
|
||||
# ROCM custom allreduce
|
||||
# ROCM custom allreduce
|
||||
|
||||
def init_custom_ar(
|
||||
meta: torch.Tensor,
|
||||
rank_data: torch.Tensor,
|
||||
handles: List[str],
|
||||
offsets: List[int],
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return sgl_kernel.allreduce.init_custom_ar(
|
||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||
)
|
||||
def init_custom_ar(
|
||||
meta: torch.Tensor,
|
||||
rank_data: torch.Tensor,
|
||||
handles: List[str],
|
||||
offsets: List[int],
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return sgl_kernel.allreduce.init_custom_ar(
|
||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||
)
|
||||
|
||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
|
||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
|
||||
|
||||
def all_reduce_unreg(
|
||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
def all_reduce_unreg(
|
||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
sgl_kernel.allreduce.dispose(fa)
|
||||
def dispose(fa: int) -> None:
|
||||
sgl_kernel.allreduce.dispose(fa)
|
||||
|
||||
def meta_size() -> int:
|
||||
return sgl_kernel.allreduce.meta_size()
|
||||
def meta_size() -> int:
|
||||
return sgl_kernel.allreduce.meta_size()
|
||||
|
||||
def register_buffer(
|
||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||
) -> None:
|
||||
return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
|
||||
def register_buffer(
|
||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||
) -> None:
|
||||
return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||
return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||
return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[str], offsets: List[List[int]]
|
||||
) -> None:
|
||||
sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[str], offsets: List[List[int]]
|
||||
) -> None:
|
||||
sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.allocate_meta_buffer(size)
|
||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.allocate_meta_buffer(size)
|
||||
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
else:
|
||||
# TRTLLM custom allreduce
|
||||
def init_custom_ar(
|
||||
rank_id: int,
|
||||
world_size: int,
|
||||
rank_data_base: torch.Tensor,
|
||||
buffers: List[int],
|
||||
tmp_result_buffers: List[int],
|
||||
barrier_in: List[int],
|
||||
barrier_out: List[int],
|
||||
) -> int:
|
||||
return sgl_kernel.init_custom_reduce(
|
||||
rank_id,
|
||||
world_size,
|
||||
rank_data_base,
|
||||
buffers,
|
||||
tmp_result_buffers,
|
||||
barrier_in,
|
||||
barrier_out,
|
||||
)
|
||||
|
||||
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.custom_reduce(fa, inp, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
sgl_kernel.custom_dispose(fa)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
||||
return sgl_kernel.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||
) -> None:
|
||||
sgl_kernel.register_graph_buffers(fa, handles, offsets)
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
@@ -257,7 +257,7 @@ class CustomAllreduce:
|
||||
self.world_size = world_size
|
||||
self.full_nvlink = full_nvlink
|
||||
|
||||
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
if not _is_hip:
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
@@ -280,56 +280,24 @@ class CustomAllreduce:
|
||||
)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
else:
|
||||
if _is_hip:
|
||||
# meta data buffers need to be "uncached" for signal on MI200
|
||||
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
||||
self.buffer = torch.empty(
|
||||
max_size, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
handle = ops.get_meta_buffer_ipc_handle(self.meta)
|
||||
shard_data = (
|
||||
bytes(handle), # ipc handle to base ptr
|
||||
0, # offset of base ptr
|
||||
)
|
||||
handles, offsets = self._gather_ipc_meta(shard_data)
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self._ptr = ops.init_custom_ar(
|
||||
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
|
||||
)
|
||||
self.register_buffer(self.buffer)
|
||||
self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
|
||||
else:
|
||||
# From TensorRT-LLM getMaxRequiredWorkspaceSize
|
||||
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
|
||||
# meta data buffers need to be "uncached" for signal on MI200
|
||||
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
||||
self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
|
||||
handle = ops.get_meta_buffer_ipc_handle(self.meta)
|
||||
shard_data = (
|
||||
bytes(handle), # ipc handle to base ptr
|
||||
0, # offset of base ptr
|
||||
)
|
||||
handles, offsets = self._gather_ipc_meta(shard_data)
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self._ptr = ops.init_custom_ar(
|
||||
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
|
||||
)
|
||||
self.register_buffer(self.buffer)
|
||||
self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
|
||||
|
||||
# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
|
||||
self.barrier_max_size = 8 * (36 + 2) * 8
|
||||
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
|
||||
max_size, group=group
|
||||
)
|
||||
self.rank_data_base = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self.barrier_in_ptrs = self.create_shared_buffer(
|
||||
self.barrier_max_size, group=group
|
||||
)
|
||||
self.barrier_out_ptrs = self.create_shared_buffer(
|
||||
self.barrier_max_size, group=group
|
||||
)
|
||||
|
||||
self._ptr = ops.init_custom_ar(
|
||||
rank,
|
||||
world_size,
|
||||
self.rank_data_base,
|
||||
self.buffer_ptrs,
|
||||
self.tmp_result_buffer_ptrs,
|
||||
self.barrier_in_ptrs,
|
||||
self.barrier_out_ptrs,
|
||||
)
|
||||
self.disabled = False
|
||||
|
||||
@staticmethod
|
||||
@@ -455,7 +423,7 @@ class CustomAllreduce:
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
if not _is_hip:
|
||||
if self.world_size == 2 or self.full_nvlink:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
@@ -471,18 +439,6 @@ class CustomAllreduce:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
if self.world_size == 2:
|
||||
return (
|
||||
inp_size < self.max_size
|
||||
and inp_size < self.max_required_workspace_size[0]
|
||||
)
|
||||
|
||||
if self.full_nvlink:
|
||||
return (
|
||||
inp_size < self.max_size
|
||||
and inp_size < self.max_required_workspace_size[1]
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
# all reduce, assuming inp tensor is IPC registered with register_buffer,
|
||||
@@ -515,15 +471,12 @@ class CustomAllreduce:
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
if registered:
|
||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||
else:
|
||||
ops.all_reduce(
|
||||
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
|
||||
)
|
||||
if registered:
|
||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||
else:
|
||||
ops.all_reduce(self._ptr, inp, out)
|
||||
ops.all_reduce(
|
||||
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
|
||||
)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
@@ -554,14 +507,9 @@ class CustomAllreduce:
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
ops.dispose(self._ptr)
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
if _is_cuda:
|
||||
self.free_shared_buffer(self.meta_ptrs)
|
||||
self.free_shared_buffer(self.buffer_ptrs)
|
||||
elif _is_cuda:
|
||||
self.free_shared_buffer(self.buffer_ptrs)
|
||||
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
|
||||
self.free_shared_buffer(self.barrier_in_ptrs)
|
||||
self.free_shared_buffer(self.barrier_out_ptrs)
|
||||
self._ptr = 0
|
||||
|
||||
def __del__(self):
|
||||
|
||||
Reference in New Issue
Block a user