diff --git a/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py b/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
index c74b4938e..f130c3fbb 100644
--- a/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
+++ b/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
@@ -4,7 +4,7 @@ import time
 
 import torch
 
-from sglang import ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 87a392d55..4c1e8ddfa 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -127,21 +127,45 @@ class RMSNorm(CustomOp):
             return output, residual_out
         return rms_norm(x, self.weight.data, self.variance_epsilon)
 
+    # def forward_hip(
+    #     self,
+    #     x: torch.Tensor,
+    #     residual: Optional[torch.Tensor] = None,
+    # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    #     if not x.is_contiguous():
+    #         # NOTE: Remove this if aiter kernel supports discontinuous input
+    #         x = x.contiguous()
+    #     if residual is not None:
+    #         if _vllm_version < Version("0.9"):
+    #             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+    #             return x, residual
+    #         else:
+    #             residual_out = torch.empty_like(x)
+    #             output = torch.empty_like(x)
+    #             fused_add_rms_norm(
+    #                 output,
+    #                 x,
+    #                 residual_out,
+    #                 residual,
+    #                 self.weight.data,
+    #                 self.variance_epsilon,
+    #             )
+    #             return output, residual_out
+    #     out = torch.empty_like(x)
+    #     rms_norm(out, x, self.weight.data, self.variance_epsilon)
+    #     return out
     def forward_hip(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
         if not x.is_contiguous():
-            # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
+
         if residual is not None:
-            if _vllm_version < Version("0.9"):
-                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-                return x, residual
-            else:
-                residual_out = torch.empty_like(x)
+            try:
                 output = torch.empty_like(x)
+                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     output,
                     x,
@@ -151,10 +175,21 @@ class RMSNorm(CustomOp):
                     self.variance_epsilon,
                 )
                 return output, residual_out
+            except TypeError:
+                fused_add_rms_norm(
+                    x,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return x, residual
+
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out
+
+
     def forward_native(
         self,
         x: torch.Tensor,
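
Reviewer note on the `forward_hip` rewrite above: instead of branching on `_vllm_version`, the new code feature-detects which `fused_add_rms_norm` kernel is installed, since the symbol exists with two incompatible signatures (a newer 6-argument out-of-place form and an older 4-argument in-place form). A minimal standalone sketch of the same dispatch pattern; `add_rms_norm_compat` is an illustrative name, and the kernel is passed in rather than imported because its origin (aiter vs. vllm) depends on the build:

```python
import torch

def add_rms_norm_compat(fused_add_rms_norm, x, residual, weight, eps):
    """Dispatch between the two known fused_add_rms_norm signatures."""
    try:
        # Newer kernels: out-of-place; outputs are the 1st and 3rd arguments.
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        fused_add_rms_norm(output, x, residual_out, residual, weight, eps)
        return output, residual_out
    except TypeError:
        # Older kernels: 4-argument form that mutates x and residual in place.
        fused_add_rms_norm(x, residual, weight, eps)
        return x, residual
```

One caveat worth raising in review: `except TypeError` also swallows genuine bad-argument bugs, and the failed 6-argument attempt is retried on every call; probing once at import time and caching the result would avoid both.
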
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 6d3fb53b0..118ec132f 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -61,7 +61,7 @@ def inplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -79,6 +79,8 @@ def inplace_fused_experts(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ) -> None:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     fused_experts_impl(
         hidden_states,
         w1,
@@ -117,7 +119,7 @@ def inplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -154,7 +156,7 @@ def outplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -173,6 +175,8 @@ def outplace_fused_experts(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ) -> torch.Tensor:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     return fused_experts_impl(
         hidden_states,
         w1,
@@ -211,7 +215,7 @@ def outplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -263,6 +267,13 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
+    act_id = (
+        0 if (
+            moe_runner_config.activation == 0
+            or (isinstance(moe_runner_config.activation, str)
+                and moe_runner_config.activation.lower() == "silu")
+        ) else 1
+    )
     if moe_runner_config.inplace:
         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -273,7 +284,7 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            moe_runner_config.activation,
+            act_id,
             moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
@@ -301,7 +312,7 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            moe_runner_config.activation,
+            act_id,
             moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
@@ -345,7 +356,7 @@ def fused_experts_impl(
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -364,6 +375,9 @@ def fused_experts_impl(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ):
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
+
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0
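
Note on the `activation` plumbing above: the registered `torch.ops.sglang.*` entry points now take an integer code (presumably to keep the op schema to primitive scalars on this stack), while `fused_experts_impl` still works with string names, so each boundary converts. A small sketch of the convention as the patch defines it; the helper names `act_to_id`/`id_to_act` are illustrative, not part of the patch:

```python
ACT_SILU, ACT_GELU = 0, 1  # integer codes used at the custom-op boundary

def act_to_id(activation) -> int:
    # Mirrors the act_id expression in fused_experts(): anything that is
    # not 0 / "silu" is treated as gelu.
    if isinstance(activation, int):
        return activation
    return ACT_SILU if activation.lower() == "silu" else ACT_GELU

def id_to_act(act_id: int) -> str:
    # Mirrors the isinstance(activation, int) shim at each entry point.
    return "silu" if act_id == 0 else "gelu"

assert id_to_act(act_to_id("SiLU")) == "silu"
assert id_to_act(act_to_id("gelu")) == "gelu"
```

Worth flagging in review: the mapping is never validated, so a misspelled activation such as "swiglu" silently becomes "gelu" instead of raising.
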
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 8df5dffb6..42916486a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -516,7 +516,7 @@ class ModelRunner:
         ):
             server_args.attention_backend = "fa3"
         elif _is_hip:
-            server_args.attention_backend = "aiter"
+            server_args.attention_backend = "triton"
         elif _is_npu:
             server_args.attention_backend = "ascend"
         else:
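
The change above switches the default attention backend on HIP builds from aiter to Triton. Anyone who still wants the aiter path can select it explicitly; a sketch, assuming the standard sglang server argument (the model path is a placeholder):

```python
# CLI equivalent: python -m sglang.launch_server ... --attention-backend aiter
from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="/path/to/model", attention_backend="aiter")
```
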
diff --git a/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh b/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
index ff4d28d29..4aeedebe4 100644
--- a/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
+++ b/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
@@ -165,10 +165,10 @@ DINLINE void start_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
-    __scoped_atomic_store_n(
-        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __hip_atomic_store(
+        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
+    while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
         flag)
       ;
   }
@@ -211,16 +211,16 @@ DINLINE void end_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(
+    __hip_atomic_store(
         &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
         flag,
         final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
-        __MEMORY_SCOPE_SYSTEM);
+        __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(
+    while (__hip_atomic_load(
                &self_sg->end[blockIdx.x][threadIdx.x],
                final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
-               __MEMORY_SCOPE_DEVICE) < flag)
+               __HIP_MEMORY_SCOPE_AGENT) < flag)
       ;
   }
   __syncthreads();
diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu
index 92fd34270..d97487df3 100644
--- a/sgl-kernel/csrc/moe/moe_align_kernel.cu
+++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "utils.h"
 
+#define WARP_SIZE 64
 #define VEC_SIZE 4
 using Vec = int4;
 
@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
   int original = v;
 #pragma unroll
   for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
-    int n = __shfl_up_sync(mask, v, offset);
+    int n = __shfl_up(v, offset);
     if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
   }
   return v - original;
diff --git a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
index c9bc8a628..dee3b2296 100644
--- a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
+++ b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
@@ -60,7 +60,7 @@ template <typename T>
 __device__ float convert_to_float(T x) {
   if constexpr (std::is_same_v<T, __half>) {
     return __half2float(x);
-  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+  } else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
     return __bfloat162float(x);
   } else if constexpr (std::is_same_v<T, float>) {
     return x;
@@ -575,8 +575,8 @@ void topk_softmax(
         renormalize,
         stream);
   } else if (dtype == at::ScalarType::BFloat16) {
-    topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
-        reinterpret_cast<__nv_bfloat16*>(gating_output.data_ptr()),
+    topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
+        reinterpret_cast<__hip_bfloat16*>(gating_output.data_ptr()),
         topk_weights.data_ptr<float>(),
         topk_indices.data_ptr<int>(),
         softmax_workspace.data_ptr<float>(),
diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h
index 5cab0786c..7cfef8136 100644
--- a/sgl-kernel/include/utils.h
+++ b/sgl-kernel/include/utils.h
@@ -358,25 +358,25 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) {
 #endif
 
 // add FP8 support
-#ifndef USE_ROCM
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else  // USE_ROCM
-#if HIP_FP8_TYPE_FNUZ
-#include <c10/util/Float8_e4m3fnuz.h>
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-constexpr auto FP8_E4M3_MAX = 224.0f;
-#else
-#if HIP_FP8_TYPE_E4M3
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else
-#error "fp8 is not supported in this processor (arch < gfx942)."
-#endif  // HIP_FP8_TYPE_E4M3
-#endif  // HIP_FP8_TYPE_FNUZ
-#endif  // USE_ROCM
+// #ifndef USE_ROCM
+// #include <c10/util/Float8_e4m3fn.h>
+// using FP8_TYPE = c10::Float8_e4m3fn;
+// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+// #else  // USE_ROCM
+// #if HIP_FP8_TYPE_FNUZ
+// #include <c10/util/Float8_e4m3fnuz.h>
+// using FP8_TYPE = c10::Float8_e4m3fnuz;
+// constexpr auto FP8_E4M3_MAX = 224.0f;
+// #else
+// #if HIP_FP8_TYPE_E4M3
+// #include <c10/util/Float8_e4m3fn.h>
+// using FP8_TYPE = c10::Float8_e4m3fn;
+// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+// #else
+// #error "fp8 is not supported in this processor (arch < gfx942)."
+// #endif  // HIP_FP8_TYPE_E4M3
+// #endif  // HIP_FP8_TYPE_FNUZ
+// #endif  // USE_ROCM
 
 #define FULL_MASK 0xffffffff
diff --git a/sgl-kernel/setup_hip.py b/sgl-kernel/setup_hip.py
new file mode 100644
index 000000000..67d0eea28
--- /dev/null
+++ b/sgl-kernel/setup_hip.py
@@ -0,0 +1,100 @@
+# Copyright 2025 SGLang Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import platform
+import sys
+from pathlib import Path
+
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+root = Path(__file__).parent.resolve()
+arch = platform.machine().lower()
+
+
+def _get_version():
+    with open(root / "pyproject.toml") as f:
+        for line in f:
+            if line.startswith("version"):
+                return line.split("=")[1].strip().strip('"')
+
+
+operator_namespace = "sgl_kernel"
+include_dirs = [
+    root / "include",
+    root / "csrc",
+]
+
+sources = [
+    "csrc/allreduce/custom_all_reduce.hip",
+    "csrc/allreduce/quick_all_reduce.cu",
+    "csrc/common_extension_rocm.cc",
+    "csrc/elementwise/activation.cu",
+    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu",
+    "csrc/moe/moe_align_kernel.cu",
+    "csrc/moe/moe_topk_softmax_kernels.cu",
+    "csrc/speculative/eagle_utils.cu",
+    "csrc/kvcacheio/transfer.cu",
+]
+
+cxx_flags = [
+    "-O3",
+    "-Wno-switch-bool",
+    "-Wno-macro-redefined",
+    "-Wno-deprecated-declarations",
+    "-w",
+]
+libraries = ["c10", "torch", "torch_python"]
+extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
+
+hipcc_flags = [
+    "-fPIC",
+    "-O3",
+    "-std=c++17",
+    "-D__HIP_PLATFORM_HCC__=1",
+    "--offload-arch=gfx928",
+    "--offload-arch=gfx936",
+    "--gpu-max-threads-per-block=1024",
+    "-Wno-macro-redefined",
+    "-Wno-deprecated-declarations",
+    "-funroll-loops",
+    "-Rpass-analysis=unroll-loops",
+    "-w",
+]
+
+ext_modules = [
+    CUDAExtension(
+        name="sgl_kernel.common_ops",
+        sources=sources,
+        include_dirs=include_dirs,
+        extra_compile_args={
+            "nvcc": hipcc_flags,
+            "cxx": cxx_flags,
+        },
+        libraries=libraries,
+        extra_link_args=extra_link_args,
+        py_limited_api=False,
+    ),
+]
+
+setup(
+    name="sgl-kernel",
+    version=_get_version(),
+    packages=find_packages(where="python"),
+    package_dir={"": "python"},
+    ext_modules=ext_modules,
+    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
+    options={"bdist_wheel": {"py_limited_api": "cp39"}},
+)
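
Finally, a hedged sketch of how the new `setup_hip.py` build might be exercised and smoke-tested; the `sgl_kernel.common_ops` module name comes from the extension definition above, and listing `torch.ops.sgl_kernel` assumes the ops land under the `operator_namespace` declared in the script:

```python
# Build from sgl-kernel/ first:  python setup_hip.py bdist_wheel
import torch
import sgl_kernel.common_ops  # noqa: F401  -- loads the compiled HIP extension

# If the load succeeded, the ops registered by common_extension_rocm.cc
# should show up in the sgl_kernel op namespace.
print(dir(torch.ops.sgl_kernel))
```
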