Clean up custom allreduce (#4029)

This commit is contained in:
Lianmin Zheng
2025-03-03 04:59:53 -08:00
committed by GitHub
parent 66301e124f
commit a3ab768a2b
2 changed files with 24 additions and 77 deletions

View File

@@ -22,17 +22,18 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
logger = logging.getLogger(__name__)
is_hip_ = is_hip()
if is_cuda():
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if is_hip():
if is_hip_:
try:
from amdsmi import (
AmdSmiException,
amdsmi_get_gpu_board_info,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
@@ -42,9 +43,11 @@ if is_hip():
logger.warning("Failed to import amdsmi with %r", e)
try:
if ops.use_vllm_custom_allreduce and not is_hip():
if ops.use_vllm_custom_allreduce and not is_hip_:
# Use vLLM custom allreduce
ops.meta_size()
else:
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import sgl_kernel
custom_ar = True
except Exception:
@@ -60,7 +63,7 @@ _R = TypeVar("_R")
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if torch.version.hip:
if is_hip_:
try:
amdsmi_init()
return fn(*args, **kwargs)
@@ -78,7 +81,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
if is_hip():
if is_hip_:
"""
query if the set of gpus are fully connected by xgmi (1 hop)
"""
@@ -142,7 +145,7 @@ def is_weak_contiguous(inp: torch.Tensor):
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
_MAX_CAR_SIZE = 8192 * 1024
if is_hip():
if is_hip_:
# crossover is at 16MB buffer size for ROCm
_MAX_CAR_SIZE = 2 * 8192 * 1024
@@ -226,7 +229,7 @@ class CustomAllreduce:
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
if is_cuda() or is_hip():
if is_cuda() or is_hip_:
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
if world_size > 2 and not full_nvlink:
@@ -240,7 +243,7 @@ class CustomAllreduce:
# this is expensive to compute at the first time
# then we cache the result
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not is_hip() and not _can_p2p(rank, world_size):
if not is_hip_ and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
@@ -253,7 +256,7 @@ class CustomAllreduce:
self.world_size = world_size
self.full_nvlink = full_nvlink
if ops.use_vllm_custom_allreduce and not is_hip():
if ops.use_vllm_custom_allreduce and not is_hip_:
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
@@ -276,7 +279,7 @@ class CustomAllreduce:
)
ops.register_buffer(self._ptr, self.buffer_ptrs)
else:
if is_hip():
if is_hip_:
# meta data buffers need to be "uncached" for signal on MI200
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
self.buffer = torch.empty(
@@ -415,7 +418,7 @@ class CustomAllreduce:
ops.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
if is_hip():
if is_hip_:
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
@@ -451,12 +454,12 @@ class CustomAllreduce:
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if ops.use_vllm_custom_allreduce and not is_hip():
if ops.use_vllm_custom_allreduce and not is_hip_:
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
if is_hip():
if is_hip_:
if self.full_nvlink:
if self.world_size == 8:
if self.MSCCL:
@@ -529,7 +532,7 @@ class CustomAllreduce:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
if is_hip():
if is_hip_:
return self.all_reduce_reg(input)
else:
return self.all_reduce(input, registered=True)
@@ -538,7 +541,7 @@ class CustomAllreduce:
# allreduce is out-of-place.
return torch.empty_like(input)
else:
if is_hip():
if is_hip_:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance