Clean up custom allreduce (#4029)
This commit is contained in:
@@ -22,17 +22,18 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
is_hip_ = is_hip()
|
||||
|
||||
if is_cuda():
|
||||
try:
|
||||
import pynvml
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import pynvml with %r", e)
|
||||
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
try:
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_gpu_board_info,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
@@ -42,9 +43,11 @@ if is_hip():
|
||||
logger.warning("Failed to import amdsmi with %r", e)
|
||||
|
||||
try:
|
||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
||||
if ops.use_vllm_custom_allreduce and not is_hip_:
|
||||
# Use vLLM custom allreduce
|
||||
ops.meta_size()
|
||||
else:
|
||||
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
|
||||
import sgl_kernel
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
@@ -60,7 +63,7 @@ _R = TypeVar("_R")
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if torch.version.hip:
|
||||
if is_hip_:
|
||||
try:
|
||||
amdsmi_init()
|
||||
return fn(*args, **kwargs)
|
||||
@@ -78,7 +81,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
"""
|
||||
query if the set of gpus are fully connected by xgmi (1 hop)
|
||||
"""
|
||||
@@ -142,7 +145,7 @@ def is_weak_contiguous(inp: torch.Tensor):
|
||||
class CustomAllreduce:
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
_MAX_CAR_SIZE = 8192 * 1024
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
# crossover is at 16MB buffer size for ROCm
|
||||
_MAX_CAR_SIZE = 2 * 8192 * 1024
|
||||
|
||||
@@ -226,7 +229,7 @@ class CustomAllreduce:
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
if is_cuda() or is_hip():
|
||||
if is_cuda() or is_hip_:
|
||||
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
|
||||
|
||||
if world_size > 2 and not full_nvlink:
|
||||
@@ -240,7 +243,7 @@ class CustomAllreduce:
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||
if not is_hip() and not _can_p2p(rank, world_size):
|
||||
if not is_hip_ and not _can_p2p(rank, world_size):
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
@@ -253,7 +256,7 @@ class CustomAllreduce:
|
||||
self.world_size = world_size
|
||||
self.full_nvlink = full_nvlink
|
||||
|
||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
||||
if ops.use_vllm_custom_allreduce and not is_hip_:
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
@@ -276,7 +279,7 @@ class CustomAllreduce:
|
||||
)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
else:
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
# meta data buffers need to be "uncached" for signal on MI200
|
||||
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
||||
self.buffer = torch.empty(
|
||||
@@ -415,7 +418,7 @@ class CustomAllreduce:
|
||||
ops.register_buffer(self._ptr, inp, handles, offsets)
|
||||
|
||||
def register_graph_buffers(self):
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
@@ -451,12 +454,12 @@ class CustomAllreduce:
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
||||
if ops.use_vllm_custom_allreduce and not is_hip_:
|
||||
if self.world_size == 2 or self.full_nvlink:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
if self.full_nvlink:
|
||||
if self.world_size == 8:
|
||||
if self.MSCCL:
|
||||
@@ -529,7 +532,7 @@ class CustomAllreduce:
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
return self.all_reduce_reg(input)
|
||||
else:
|
||||
return self.all_reduce(input, registered=True)
|
||||
@@ -538,7 +541,7 @@ class CustomAllreduce:
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
# note: outside of cuda graph context,
|
||||
# custom allreduce incurs a cost of cudaMemcpy, which should
|
||||
# be small(<=1% of overall latency) compared to the performance
|
||||
|
||||
Reference in New Issue
Block a user