diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index 7002b25d5..24c0dd79f 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -1,10 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
-import contextlib
-import functools
-import importlib
 import logging
 import os
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Tuple

 import torch
 import torch.library
@@ -13,8 +10,9 @@ from sglang.srt.utils import is_hip, is_hpu
 logger = logging.getLogger(__name__)

 use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
+
 if not is_hpu():
-    # Remove vllm dependency for custom allreduce on ROCm
+    # ROCm does not use vllm custom allreduce
     if use_vllm_custom_allreduce and not is_hip():
         try:
             import vllm._C
@@ -27,37 +25,8 @@ if not is_hpu():
             logger.warning("Failed to import from custom_ar with %r", e)


-def hint_on_error(fn):
-
-    @functools.wraps(fn)
-    def wrapper(*args, **kwargs):
-        try:
-            return fn(*args, **kwargs)
-
-        except NotImplementedError as e:
-            msg = (
-                "Error in calling custom op %s: %s\n"
-                "Not implemented or built, mostly likely because the current current device "
-                "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
-                "incorrectly while building)"
-            )
-            logger.error(msg, fn.__name__, e)
-            raise NotImplementedError(msg % (fn.__name__, e)) from e
-        except AttributeError as e:
-            msg = (
-                "Error in calling custom op %s: %s\n"
-                "Possibly you have built or installed an obsolete version of vllm.\n"
-                "Please try a clean build and install of vllm,"
-                "or remove old built files such as vllm/*cpython*.so and build/ ."
-            )
-            logger.error(msg, fn.__name__, e)
-            raise e
-
-    return wrapper
-
-
 if use_vllm_custom_allreduce and not is_hip():
-    # custom ar
+    # vLLM custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
@@ -96,6 +65,7 @@ if use_vllm_custom_allreduce and not is_hip():

 else:
     if is_hip():
+        # ROCm custom allreduce

         def init_custom_ar(
             meta: torch.Tensor,
@@ -143,7 +113,7 @@ else:
             return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)

     else:
-        # custom ar
+        # TRT-LLM custom allreduce
        def init_custom_ar(
            rank_id: int,
            world_size: int,
@@ -176,29 +146,3 @@ else:
            fa: int, handles: List[List[int]], offsets: List[List[int]]
        ) -> None:
            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
-
-
-# temporary fix for https://github.com/vllm-project/vllm/issues/5456
-# TODO: remove this in v0.6.0
-names_and_values = globals()
-names_and_values_to_update = {}
-# prepare variables to avoid dict size change during iteration
-k, v, arg = None, None, None
-fn_type = type(lambda x: x)
-for k, v in names_and_values.items():
-    # find functions that are defined in this file and have torch.Tensor
-    # in their annotations. `arg == "torch.Tensor"` is used to handle
-    # the case when users use `import __annotations__` to turn type
-    # hints into strings.
-    if (
-        isinstance(v, fn_type)
-        and v.__code__.co_filename == __file__
-        and any(
-            arg is torch.Tensor or arg == "torch.Tensor"
-            for arg in v.__annotations__.values()
-        )
-    ):
-        names_and_values_to_update[k] = hint_on_error(v)
-
-names_and_values.update(names_and_values_to_update)
-del names_and_values_to_update, names_and_values, v, k, fn_type
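For orientation before the second file: the edited _custom_ops.py binds one of three custom-allreduce implementations at import time. The sketch below summarizes that dispatch; it is illustrative only — pick_allreduce_backend is a hypothetical helper that does not exist in the module, the branch conditions are copied from the diff above, and the "none" branch is a simplification (in the module, the HPU check only skips the vllm._C import). Note also that os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True) returns a string whenever the variable is set, so any non-empty value, including "0", keeps the vLLM path enabled.

# Illustrative sketch only: summarizes the import-time dispatch in the diff
# above. pick_allreduce_backend is hypothetical and not part of sglang.
import os

from sglang.srt.utils import is_hip, is_hpu

use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)


def pick_allreduce_backend() -> str:
    """Name the custom-allreduce backend the module would bind."""
    if is_hpu():
        return "none"  # simplified: in the module this only skips vllm._C
    if use_vllm_custom_allreduce and not is_hip():
        return "vllm"  # vllm._C custom allreduce (CUDA)
    if is_hip():
        return "sgl_kernel-rocm"  # ROCm custom allreduce from sgl_kernel
    return "sgl_kernel-trtllm"  # TRT-LLM custom allreduce from sgl_kernel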
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index 4eb28ced7..efd943b0e 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -22,17 +22,18 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip

 logger = logging.getLogger(__name__)

+is_hip_ = is_hip()
+
 if is_cuda():
     try:
         import pynvml
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)

-if is_hip():
+if is_hip_:
     try:
         from amdsmi import (
             AmdSmiException,
-            amdsmi_get_gpu_board_info,
             amdsmi_get_processor_handles,
             amdsmi_init,
             amdsmi_shut_down,
@@ -42,9 +43,11 @@ if is_hip():
         logger.warning("Failed to import amdsmi with %r", e)

 try:
-    if ops.use_vllm_custom_allreduce and not is_hip():
+    if ops.use_vllm_custom_allreduce and not is_hip_:
+        # Use vLLM custom allreduce
         ops.meta_size()
     else:
+        # Use custom allreduce from sgl_kernel (ROCm and TRT-LLM)
         import sgl_kernel
     custom_ar = True
 except Exception:
@@ -60,7 +63,7 @@ _R = TypeVar("_R")
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        if torch.version.hip:
+        if is_hip_:
             try:
                 amdsmi_init()
                 return fn(*args, **kwargs)
@@ -78,7 +81,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

 @with_nvml_context
 def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
-    if is_hip():
+    if is_hip_:
         """
         query if the set of gpus are fully connected by xgmi (1 hop)
         """
@@ -142,7 +145,7 @@ def is_weak_contiguous(inp: torch.Tensor):

 class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
     _MAX_CAR_SIZE = 8192 * 1024
-    if is_hip():
+    if is_hip_:
         # crossover is at 16MB buffer size for ROCm
         _MAX_CAR_SIZE = 2 * 8192 * 1024

@@ -226,7 +229,7 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda() or is_hip():
+        if is_cuda() or is_hip_:
             full_nvlink = is_full_nvlink(physical_device_ids, world_size)

         if world_size > 2 and not full_nvlink:
@@ -240,7 +243,7 @@ class CustomAllreduce:
         # this is expensive to compute at the first time
         # then we cache the result
         # On AMD GPU, p2p is always enabled between XGMI connected GPUs
-        if not is_hip() and not _can_p2p(rank, world_size):
+        if not is_hip_ and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "
@@ -253,7 +256,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink

-        if ops.use_vllm_custom_allreduce and not is_hip():
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
@@ -276,7 +279,7 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            if is_hip():
+            if is_hip_:
                 # meta data buffers need to be "uncached" for signal on MI200
                 self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
                 self.buffer = torch.empty(
@@ -415,7 +418,7 @@ class CustomAllreduce:
         ops.register_buffer(self._ptr, inp, handles, offsets)

     def register_graph_buffers(self):
-        if is_hip():
+        if is_hip_:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
             handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
             logger.info("Registering %d cuda graph addresses", len(offset))
@@ -451,12 +454,12 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not is_hip():
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False

-        if is_hip():
+        if is_hip_:
             if self.full_nvlink:
                 if self.world_size == 8:
                     if self.MSCCL:
@@ -529,7 +532,7 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if is_hip():
+                if is_hip_:
                     return self.all_reduce_reg(input)
                 else:
                     return self.all_reduce(input, registered=True)
@@ -538,7 +541,7 @@ class CustomAllreduce:
             # allreduce is out-of-place.
             return torch.empty_like(input)
         else:
-            if is_hip():
+            if is_hip_:
                 # note: outside of cuda graph context,
                 # custom allreduce incurs a cost of cudaMemcpy, which should
                 # be small(<=1% of overall latency) compared to the performance
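The custom_all_reduce.py changes are mechanical: is_hip() is evaluated once into the module-level is_hip_ flag, which also replaces the lone torch.version.hip check in with_nvml_context, and comments mark which backend each branch serves. The eligibility test those branches feed is unchanged by this diff; below is a hedged sketch of it under the constants visible above. should_use_custom_allreduce is a hypothetical stand-in for CustomAllreduce.should_custom_ar, and it omits the weak-contiguity check and the ROCm MSCCL special cases the real method has.

# Hedged sketch of the size/topology gating in CustomAllreduce.should_custom_ar.
# should_use_custom_allreduce is hypothetical; the constants mirror the diff.
from typing import List

import torch

_SUPPORTED_WORLD_SIZES: List[int] = [2, 4, 6, 8]
_MAX_CAR_SIZE_CUDA = 8192 * 1024  # 8 MiB: crossover vs. NCCL on CUDA
_MAX_CAR_SIZE_ROCM = 2 * 8192 * 1024  # 16 MiB: crossover on ROCm (per the diff)


def should_use_custom_allreduce(
    inp: torch.Tensor, world_size: int, full_nvlink: bool, is_rocm: bool
) -> bool:
    if world_size not in _SUPPORTED_WORLD_SIZES:
        return False
    # With 4+ GPUs and no NVLink/xGMI mesh, custom allreduce gains little
    # over NCCL, so fall back.
    if world_size > 2 and not full_nvlink:
        return False
    max_size = _MAX_CAR_SIZE_ROCM if is_rocm else _MAX_CAR_SIZE_CUDA
    return inp.numel() * inp.element_size() < max_size

For example, a 4 MiB fp16 tensor on an NVLink-connected 4-GPU CUDA node passes this gate, while the same tensor across four PCIe-only GPUs falls back to NCCL.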