adapt custom allreduce for tensorrt llm (#2511)
This commit is contained in:
@@ -27,7 +27,7 @@ runtime_common = [
|
||||
]
|
||||
srt = [
|
||||
"sglang[runtime_common]", "cuda-python",
|
||||
"sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
|
||||
"sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
|
||||
"flashinfer==0.1.6"
|
||||
]
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
||||
import contextlib
|
||||
import functools
|
||||
import importlib
|
||||
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if not is_hpu():
|
||||
try:
|
||||
import custom_ar
|
||||
import sgl_kernel
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from custom_ar with %r", e)
|
||||
|
||||
@@ -50,46 +50,41 @@ def hint_on_error(fn):
|
||||
|
||||
# custom ar
|
||||
def init_custom_ar(
|
||||
ipc_tensors: List[torch.Tensor],
|
||||
rank_data: torch.Tensor,
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
rank_id: int,
|
||||
world_size: int,
|
||||
rank_data_base: torch.Tensor,
|
||||
buffers: List[int],
|
||||
tmp_result_buffers: List[int],
|
||||
barrier_in: List[int],
|
||||
barrier_out: List[int],
|
||||
) -> int:
|
||||
return torch.ops._C_vllm_ar.init_custom_ar(
|
||||
ipc_tensors, rank_data, rank, full_nvlink
|
||||
return sgl_kernel.ops.init_custom_reduce(
|
||||
rank_id,
|
||||
world_size,
|
||||
rank_data_base,
|
||||
buffers,
|
||||
tmp_result_buffers,
|
||||
barrier_in,
|
||||
barrier_out,
|
||||
)
|
||||
|
||||
|
||||
def all_reduce(
|
||||
fa: int,
|
||||
inp: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
reg_buffer: int,
|
||||
reg_buffer_sz_bytes: int,
|
||||
) -> None:
|
||||
torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
|
||||
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.ops.custom_reduce(fa, inp, out)
|
||||
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops._C_vllm_ar.dispose(fa)
|
||||
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops._C_vllm_ar.meta_size()
|
||||
|
||||
|
||||
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
|
||||
return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
|
||||
sgl_kernel.ops.custom_dispose(fa)
|
||||
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
||||
return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
|
||||
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||
) -> None:
|
||||
torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
|
||||
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
|
||||
|
||||
@@ -21,7 +21,8 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
|
||||
|
||||
try:
|
||||
ops.meta_size()
|
||||
import sgl_kernel
|
||||
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For AMD GPUs and CPUs
|
||||
@@ -29,7 +30,6 @@ except Exception:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
@@ -47,7 +47,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
|
||||
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
@@ -196,32 +196,39 @@ class CustomAllreduce:
|
||||
)
|
||||
return
|
||||
|
||||
self.disabled = False
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
self.meta_ptrs = self.create_shared_buffer(
|
||||
ops.meta_size() + max_size, group=group
|
||||
)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
# This is a buffer for storing the tuples of pointers pointing to
|
||||
# IPC buffers from all ranks. Each registered tuple has size of
|
||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs less than 10000 of registered tuples.
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self.max_size = max_size
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.full_nvlink = full_nvlink
|
||||
self._ptr = ops.init_custom_ar(
|
||||
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
|
||||
|
||||
# From TensorRT-LLM getMaxRequiredWorkspaceSize
|
||||
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
|
||||
|
||||
# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
|
||||
self.barrier_max_size = 8 * (36 + 2) * 8
|
||||
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
self.rank_data_base = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
self.barrier_in_ptrs = self.create_shared_buffer(
|
||||
self.barrier_max_size, group=group
|
||||
)
|
||||
self.barrier_out_ptrs = self.create_shared_buffer(
|
||||
self.barrier_max_size, group=group
|
||||
)
|
||||
|
||||
self._ptr = ops.init_custom_ar(
|
||||
rank,
|
||||
world_size,
|
||||
self.rank_data_base,
|
||||
self.buffer_ptrs,
|
||||
self.tmp_result_buffer_ptrs,
|
||||
self.barrier_in_ptrs,
|
||||
self.barrier_out_ptrs,
|
||||
)
|
||||
self.disabled = False
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(
|
||||
@@ -300,12 +307,25 @@ class CustomAllreduce:
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if self.world_size == 2 or self.full_nvlink:
|
||||
return inp_size < self.max_size
|
||||
if self.world_size == 2:
|
||||
return (
|
||||
inp_size < self.max_size
|
||||
and inp_size < self.max_required_workspace_size[0]
|
||||
)
|
||||
|
||||
if self.full_nvlink:
|
||||
return (
|
||||
inp_size < self.max_size
|
||||
and inp_size < self.max_required_workspace_size[1]
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
def all_reduce(
|
||||
self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
|
||||
self,
|
||||
inp: torch.Tensor,
|
||||
*,
|
||||
out: torch.Tensor = None,
|
||||
):
|
||||
"""Performs an out-of-place all reduce.
|
||||
|
||||
@@ -315,12 +335,7 @@ class CustomAllreduce:
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
if registered:
|
||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||
else:
|
||||
ops.all_reduce(
|
||||
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
|
||||
)
|
||||
ops.all_reduce(self._ptr, inp, out)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
@@ -330,23 +345,22 @@ class CustomAllreduce:
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
return self.all_reduce(input, registered=True)
|
||||
return self.all_reduce(input)
|
||||
else:
|
||||
# If warm up, mimic the allocation pattern since custom
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
# Note: outside of cuda graph context, custom allreduce incurs a
|
||||
# cost of cudaMemcpy, which should be small (<=1% of overall
|
||||
# latency) compared to the performance gain of using custom kernels
|
||||
return self.all_reduce(input, registered=False)
|
||||
return self.all_reduce(input)
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
ops.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
self.free_shared_buffer(self.meta_ptrs)
|
||||
self.free_shared_buffer(self.buffer_ptrs)
|
||||
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
|
||||
self.free_shared_buffer(self.barrier_in_ptrs)
|
||||
self.free_shared_buffer(self.barrier_out_ptrs)
|
||||
self._ptr = 0
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
Reference in New Issue
Block a user