Clean up custom allreduce (#4029)
This commit is contained in:
@@ -1,10 +1,7 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
||||||
import contextlib
|
|
||||||
import functools
|
|
||||||
import importlib
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.library
|
import torch.library
|
||||||
@@ -13,8 +10,9 @@ from sglang.srt.utils import is_hip, is_hpu
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
|
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
|
||||||
|
|
||||||
if not is_hpu():
|
if not is_hpu():
|
||||||
# Remove vllm dependency for custom allreduce on ROCm
|
# ROCm does not use vllm custom allreduce
|
||||||
if use_vllm_custom_allreduce and not is_hip():
|
if use_vllm_custom_allreduce and not is_hip():
|
||||||
try:
|
try:
|
||||||
import vllm._C
|
import vllm._C
|
||||||
@@ -27,37 +25,8 @@ if not is_hpu():
|
|||||||
logger.warning("Failed to import from custom_ar with %r", e)
|
logger.warning("Failed to import from custom_ar with %r", e)
|
||||||
|
|
||||||
|
|
||||||
def hint_on_error(fn):
|
|
||||||
|
|
||||||
@functools.wraps(fn)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
return fn(*args, **kwargs)
|
|
||||||
|
|
||||||
except NotImplementedError as e:
|
|
||||||
msg = (
|
|
||||||
"Error in calling custom op %s: %s\n"
|
|
||||||
"Not implemented or built, mostly likely because the current current device "
|
|
||||||
"does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
|
|
||||||
"incorrectly while building)"
|
|
||||||
)
|
|
||||||
logger.error(msg, fn.__name__, e)
|
|
||||||
raise NotImplementedError(msg % (fn.__name__, e)) from e
|
|
||||||
except AttributeError as e:
|
|
||||||
msg = (
|
|
||||||
"Error in calling custom op %s: %s\n"
|
|
||||||
"Possibly you have built or installed an obsolete version of vllm.\n"
|
|
||||||
"Please try a clean build and install of vllm,"
|
|
||||||
"or remove old built files such as vllm/*cpython*.so and build/ ."
|
|
||||||
)
|
|
||||||
logger.error(msg, fn.__name__, e)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
if use_vllm_custom_allreduce and not is_hip():
|
if use_vllm_custom_allreduce and not is_hip():
|
||||||
# custom ar
|
# vLLM custom allreduce
|
||||||
def init_custom_ar(
|
def init_custom_ar(
|
||||||
ipc_tensors: List[torch.Tensor],
|
ipc_tensors: List[torch.Tensor],
|
||||||
rank_data: torch.Tensor,
|
rank_data: torch.Tensor,
|
||||||
@@ -96,6 +65,7 @@ if use_vllm_custom_allreduce and not is_hip():
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
if is_hip():
|
if is_hip():
|
||||||
|
# ROCm custom allreduce
|
||||||
|
|
||||||
def init_custom_ar(
|
def init_custom_ar(
|
||||||
meta: torch.Tensor,
|
meta: torch.Tensor,
|
||||||
@@ -143,7 +113,7 @@ else:
|
|||||||
return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
|
return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# custom ar
|
# TRTLLM custom allreduce
|
||||||
def init_custom_ar(
|
def init_custom_ar(
|
||||||
rank_id: int,
|
rank_id: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
@@ -176,29 +146,3 @@ else:
|
|||||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||||
) -> None:
|
) -> None:
|
||||||
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
|
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
|
||||||
|
|
||||||
|
|
||||||
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
|
|
||||||
# TODO: remove this in v0.6.0
|
|
||||||
names_and_values = globals()
|
|
||||||
names_and_values_to_update = {}
|
|
||||||
# prepare variables to avoid dict size change during iteration
|
|
||||||
k, v, arg = None, None, None
|
|
||||||
fn_type = type(lambda x: x)
|
|
||||||
for k, v in names_and_values.items():
|
|
||||||
# find functions that are defined in this file and have torch.Tensor
|
|
||||||
# in their annotations. `arg == "torch.Tensor"` is used to handle
|
|
||||||
# the case when users use `from __future__ import annotations` to turn type
|
|
||||||
# hints into strings.
|
|
||||||
if (
|
|
||||||
isinstance(v, fn_type)
|
|
||||||
and v.__code__.co_filename == __file__
|
|
||||||
and any(
|
|
||||||
arg is torch.Tensor or arg == "torch.Tensor"
|
|
||||||
for arg in v.__annotations__.values()
|
|
||||||
)
|
|
||||||
):
|
|
||||||
names_and_values_to_update[k] = hint_on_error(v)
|
|
||||||
|
|
||||||
names_and_values.update(names_and_values_to_update)
|
|
||||||
del names_and_values_to_update, names_and_values, v, k, fn_type
|
|
||||||
|
|||||||
@@ -22,17 +22,18 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
is_hip_ = is_hip()
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
try:
|
try:
|
||||||
import pynvml
|
import pynvml
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Failed to import pynvml with %r", e)
|
logger.warning("Failed to import pynvml with %r", e)
|
||||||
|
|
||||||
if is_hip():
|
if is_hip_:
|
||||||
try:
|
try:
|
||||||
from amdsmi import (
|
from amdsmi import (
|
||||||
AmdSmiException,
|
AmdSmiException,
|
||||||
amdsmi_get_gpu_board_info,
|
|
||||||
amdsmi_get_processor_handles,
|
amdsmi_get_processor_handles,
|
||||||
amdsmi_init,
|
amdsmi_init,
|
||||||
amdsmi_shut_down,
|
amdsmi_shut_down,
|
||||||
@@ -42,9 +43,11 @@ if is_hip():
|
|||||||
logger.warning("Failed to import amdsmi with %r", e)
|
logger.warning("Failed to import amdsmi with %r", e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
if ops.use_vllm_custom_allreduce and not is_hip_:
|
||||||
|
# Use vLLM custom allreduce
|
||||||
ops.meta_size()
|
ops.meta_size()
|
||||||
else:
|
else:
|
||||||
|
# Use custom allreduce from sgl kernel (ROCm and TRT-LLM)
|
||||||
import sgl_kernel
|
import sgl_kernel
|
||||||
custom_ar = True
|
custom_ar = True
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -60,7 +63,7 @@ _R = TypeVar("_R")
|
|||||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||||
if torch.version.hip:
|
if is_hip_:
|
||||||
try:
|
try:
|
||||||
amdsmi_init()
|
amdsmi_init()
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
@@ -78,7 +81,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
|||||||
|
|
||||||
@with_nvml_context
|
@with_nvml_context
|
||||||
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
||||||
if is_hip():
|
if is_hip_:
|
||||||
"""
|
"""
|
||||||
query if the set of gpus are fully connected by xgmi (1 hop)
|
query if the set of gpus are fully connected by xgmi (1 hop)
|
||||||
"""
|
"""
|
||||||
@@ -142,7 +145,7 @@ def is_weak_contiguous(inp: torch.Tensor):
|
|||||||
class CustomAllreduce:
|
class CustomAllreduce:
|
||||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||||
_MAX_CAR_SIZE = 8192 * 1024
|
_MAX_CAR_SIZE = 8192 * 1024
|
||||||
if is_hip():
|
if is_hip_:
|
||||||
# crossover is at 16MB buffer size for ROCm
|
# crossover is at 16MB buffer size for ROCm
|
||||||
_MAX_CAR_SIZE = 2 * 8192 * 1024
|
_MAX_CAR_SIZE = 2 * 8192 * 1024
|
||||||
|
|
||||||
@@ -226,7 +229,7 @@ class CustomAllreduce:
|
|||||||
# test nvlink first, this will filter out most of the cases
|
# test nvlink first, this will filter out most of the cases
|
||||||
# where custom allreduce is not supported
|
# where custom allreduce is not supported
|
||||||
# this checks hardware and driver support for NVLink
|
# this checks hardware and driver support for NVLink
|
||||||
if is_cuda() or is_hip():
|
if is_cuda() or is_hip_:
|
||||||
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
|
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
|
||||||
|
|
||||||
if world_size > 2 and not full_nvlink:
|
if world_size > 2 and not full_nvlink:
|
||||||
@@ -240,7 +243,7 @@ class CustomAllreduce:
|
|||||||
# this is expensive to compute at the first time
|
# this is expensive to compute at the first time
|
||||||
# then we cache the result
|
# then we cache the result
|
||||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||||
if not is_hip() and not _can_p2p(rank, world_size):
|
if not is_hip_ and not _can_p2p(rank, world_size):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Custom allreduce is disabled because your platform lacks "
|
"Custom allreduce is disabled because your platform lacks "
|
||||||
"GPU P2P capability or P2P test failed. To silence this "
|
"GPU P2P capability or P2P test failed. To silence this "
|
||||||
@@ -253,7 +256,7 @@ class CustomAllreduce:
|
|||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.full_nvlink = full_nvlink
|
self.full_nvlink = full_nvlink
|
||||||
|
|
||||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
if ops.use_vllm_custom_allreduce and not is_hip_:
|
||||||
# Buffer memory is owned by this Python class and passed to C++.
|
# Buffer memory is owned by this Python class and passed to C++.
|
||||||
# Metadata is composed of two parts: metadata for synchronization and a
|
# Metadata is composed of two parts: metadata for synchronization and a
|
||||||
# temporary buffer for storing intermediate allreduce results.
|
# temporary buffer for storing intermediate allreduce results.
|
||||||
@@ -276,7 +279,7 @@ class CustomAllreduce:
|
|||||||
)
|
)
|
||||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||||
else:
|
else:
|
||||||
if is_hip():
|
if is_hip_:
|
||||||
# meta data buffers need to be "uncached" for signal on MI200
|
# meta data buffers need to be "uncached" for signal on MI200
|
||||||
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
||||||
self.buffer = torch.empty(
|
self.buffer = torch.empty(
|
||||||
@@ -415,7 +418,7 @@ class CustomAllreduce:
|
|||||||
ops.register_buffer(self._ptr, inp, handles, offsets)
|
ops.register_buffer(self._ptr, inp, handles, offsets)
|
||||||
|
|
||||||
def register_graph_buffers(self):
|
def register_graph_buffers(self):
|
||||||
if is_hip():
|
if is_hip_:
|
||||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||||
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
|
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
|
||||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||||
@@ -451,12 +454,12 @@ class CustomAllreduce:
|
|||||||
return False
|
return False
|
||||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||||
# little performance improvement over NCCL.
|
# little performance improvement over NCCL.
|
||||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
if ops.use_vllm_custom_allreduce and not is_hip_:
|
||||||
if self.world_size == 2 or self.full_nvlink:
|
if self.world_size == 2 or self.full_nvlink:
|
||||||
return inp_size < self.max_size
|
return inp_size < self.max_size
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if is_hip():
|
if is_hip_:
|
||||||
if self.full_nvlink:
|
if self.full_nvlink:
|
||||||
if self.world_size == 8:
|
if self.world_size == 8:
|
||||||
if self.MSCCL:
|
if self.MSCCL:
|
||||||
@@ -529,7 +532,7 @@ class CustomAllreduce:
|
|||||||
return None
|
return None
|
||||||
if self._IS_CAPTURING:
|
if self._IS_CAPTURING:
|
||||||
if torch.cuda.is_current_stream_capturing():
|
if torch.cuda.is_current_stream_capturing():
|
||||||
if is_hip():
|
if is_hip_:
|
||||||
return self.all_reduce_reg(input)
|
return self.all_reduce_reg(input)
|
||||||
else:
|
else:
|
||||||
return self.all_reduce(input, registered=True)
|
return self.all_reduce(input, registered=True)
|
||||||
@@ -538,7 +541,7 @@ class CustomAllreduce:
|
|||||||
# allreduce is out-of-place.
|
# allreduce is out-of-place.
|
||||||
return torch.empty_like(input)
|
return torch.empty_like(input)
|
||||||
else:
|
else:
|
||||||
if is_hip():
|
if is_hip_:
|
||||||
# note: outside of cuda graph context,
|
# note: outside of cuda graph context,
|
||||||
# custom allreduce incurs a cost of cudaMemcpy, which should
|
# custom allreduce incurs a cost of cudaMemcpy, which should
|
||||||
# be small(<=1% of overall latency) compared to the performance
|
# be small(<=1% of overall latency) compared to the performance
|
||||||
|
|||||||
Reference in New Issue
Block a user