diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 63e9fcdd3..15c2ba077 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -33,6 +33,7 @@ from sglang.srt.utils import ( cpu_has_amx_support, is_cpu, is_cuda, + is_hip, is_npu, set_weight_attrs, ) @@ -42,9 +43,12 @@ _is_cuda = is_cuda() _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_is_hip = is_hip() if _is_cuda: from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +elif _is_hip: + from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul if is_npu(): import torch_npu @@ -126,9 +130,13 @@ class QuickGELU(CustomOp): return x * torch.sigmoid(1.702 * x) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel return self.forward_native(x) + def forward_hip(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty(x.shape, dtype=x.dtype, device=x.device) + gelu_quick(x, out) + return out + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. @@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): return nn.Identity() -if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)): +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip): logger.info( - "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries." ) from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/test/test_activation.py b/python/sglang/test/test_activation.py index 38366e92b..dd5c668cf 100644 --- a/python/sglang/test/test_activation.py +++ b/python/sglang/test/test_activation.py @@ -3,9 +3,12 @@ import unittest import torch -from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.activation import GeluAndMul, QuickGELU +from sglang.srt.utils import is_hip from sglang.test.test_utils import CustomTestCase +_is_hip = is_hip() + class TestGeluAndMul(CustomTestCase): DTYPES = [torch.half, torch.bfloat16] @@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase): self._run_gelu_and_mul_test(*params) +class TestQuickGELU(CustomTestCase): + DTYPES = [torch.half, torch.bfloat16] + NUM_TOKENS = [7, 83, 2048] # batch = sequence length + DIMS = [512, 4096, 5120, 13824] # all multiples of 16 bytes + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int): + torch.manual_seed(seed) + + layer = QuickGELU().to(dtype=dtype) + + x = torch.randn(n_tok, dim, dtype=dtype, device="cuda") + + with torch.inference_mode(): + ref = layer.forward_native(x) # x * sigmoid(1.702 * x), fp32 math + if _is_hip: + out = layer.forward_hip(x) # 128-bit vectorised kernel from sgl-kernel + else: + out = layer.forward_cuda(x) + + tol = 1e-2 if dtype is torch.bfloat16 else 1e-3 + self.assertTrue( + torch.allclose(out, ref, atol=tol, rtol=tol), + msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}", + ) + print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}") + + def test_quick_gelu(self): + for params in itertools.product( + self.NUM_TOKENS, self.DIMS, 
self.DTYPES, self.SEEDS + ): + with self.subTest( + num_tokens=params[0], + dim=params[1], + dtype=params[2], + seed=params[3], + ): + self._run_gelu_quick_test(*params) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/sgl-kernel/benchmark/bench_activation.py b/sgl-kernel/benchmark/bench_activation.py new file mode 100644 index 000000000..cfea78915 --- /dev/null +++ b/sgl-kernel/benchmark/bench_activation.py @@ -0,0 +1,153 @@ +# Benchmarks SGLang kernels versus vLLM across +# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up. +import argparse +import itertools +import re +from typing import List, Tuple + +import sgl_kernel +import torch +import torch.nn.functional as F +import triton +import triton.testing +from sgl_kernel import gelu_quick # activation-only kernel +from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +from vllm import _custom_ops as vllm_ops + +if not hasattr(vllm_ops, "silu_and_mul"): + vllm_ops = torch.ops._C + + +def str2int_list(arg: str) -> List[int]: + if arg in ("", None): + return [] + if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None: + raise argparse.ArgumentTypeError(f"Bad int list: {arg}") + return [int(x) for x in arg.split(",")] + + +def calculate_diff( + kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int +) -> bool: + """Compare vLLM with SGLang for one shape.""" + device = torch.device("cuda") + + # activation-only quick GELU + if kernel == "gelu_quick": + x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device) + ref_out = torch.zeros_like(x) + getattr(vllm_ops, kernel)(ref_out, x) + test_out = getattr(sgl_kernel, kernel)(x) + # fused activation x mul kernels + else: + x = torch.randn(batch_size, seq_len, 2 * dim, dtype=dtype, device=device) + ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device) + getattr(vllm_ops, kernel)(ref_out, x) + test_out = getattr(sgl_kernel, kernel)(x) + + ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5) + tag = "✅ match" if ok else "❌ mismatch" + print( + f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | " + f"L={seq_len:3d} | D={dim:5d}] {tag}" + ) + return ok + + +kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"] +dtypes = [torch.float16, torch.bfloat16] + + +def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]: + return list(itertools.product(kernels, dtypes, bsizes, slens, dims_)) + + +default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16 +default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64 +default_dims = [2**i for i in range(7, 15)] # 128...16384 + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"], + x_vals=[], + line_arg="provider", + line_vals=["vllm", "sglang", "speedup"], + line_names=["vLLM", "SGL Kernel", "Speed-up (x)"], + styles=[("blue", "-"), ("green", "-"), ("red", "--")], + ylabel="µs (median) or × (speed-up)", + plot_name="activation-performance", + args={}, + ) +) +def benchmark(kernel, dtype, batch_size, seq_len, dim, provider): + device = torch.device("cuda") + in_mult = 1 if kernel == "gelu_quick" else 2 + x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device) + y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device) + + vllm_kernel = getattr(vllm_ops, kernel) + sglang_kernel = getattr(sgl_kernel, kernel) + + def baseline(): + tmp = y0.clone() + vllm_kernel(tmp, x) + return tmp + + def 
sglang(): + return sglang_kernel(x) + + # one-time correctness check + if provider == "vllm" and not calculate_diff( + kernel, dtype, batch_size, seq_len, dim + ): + raise ValueError("Mismatch – abort benchmark") + + # timing helper + def timed(fn): + for _ in range(5): + fn() + torch.cuda.synchronize() + ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + return 1000 * ms, 1000 * qmax, 1000 * qmin + + if provider == "vllm": + return timed(baseline) + if provider == "sglang": + return timed(sglang) + + # provider == "speedup" + t_ref, _, _ = timed(baseline) + t_sgl, _, _ = timed(sglang) + spd = t_ref / t_sgl + return (spd, spd, spd) + + +if __name__ == "__main__": + p = argparse.ArgumentParser("Activation kernel benchmark") + p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes) + p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens) + p.add_argument("--dims", type=str2int_list, default=default_dims) + p.add_argument("--verify_only", action="store_true") + args = p.parse_args() + + # coerce lists + if isinstance(args.batch_sizes, str): + args.batch_sizes = str2int_list(args.batch_sizes) + if isinstance(args.seq_lens, str): + args.seq_lens = str2int_list(args.seq_lens) + if isinstance(args.dims, str): + args.dims = str2int_list(args.dims) + + # patch perf_report grid + benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims) + if hasattr(benchmark, "benchmarks"): + benchmark.benchmarks.x_vals = benchmark_grid + else: + benchmark.benchmark.x_vals = benchmark_grid + + if args.verify_only: + ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0]) + print("✅ sanity pass" if ok else "❌ mismatch") + else: + benchmark.run(print_data=True) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 20b9a8048..623fbefb5 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -78,13 +78,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()"); m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); - m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); m.def( diff --git a/sgl-kernel/csrc/elementwise/activation.cu b/sgl-kernel/csrc/elementwise/activation.cu index 242281fd9..20b889530 100644 --- a/sgl-kernel/csrc/elementwise/activation.cu +++ b/sgl-kernel/csrc/elementwise/activation.cu @@ -13,70 +13,158 @@ * See the License for the specific language governing permissions and * limitations under the License. 
 */
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cmath>
+
+#ifndef USE_ROCM
+
 #include <flashinfer/activation.cuh>
 
-#include "pytorch_extension_utils.h"
+#include "utils.h"
 
-using namespace flashinfer;
+#else
+#include "hip_act_and_mul.cuh"
+#endif
 
-__device__ __forceinline__ float silu(const float& val) {
-  return val / (1.0f + __expf(-val));
+// Adapted from flashinfer activation
+// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44
+
+namespace detail {
+
+template <typename T>
+__device__ __forceinline__ float to_f32(const T& x) {
+#if USE_ROCM
+  return castToFloat(x);
+#else
+  return static_cast<float>(x);
+#endif
 }
 
-__device__ __forceinline__ float gelu(const float& val) {
+template <typename T>
+__device__ __forceinline__ T from_f32(float f32) {
+#if USE_ROCM
+  return castFromFloat<T>(f32);
+#else
+  return static_cast<T>(f32);
+#endif
+}
+
+}  // namespace detail
+
+template <typename T>
+__device__ __forceinline__ T silu(const T& x) {
+  float f32_val = detail::to_f32(x);
+  return detail::from_f32<T>(f32_val / (1.0f + expf(-f32_val)));
+}
+
+template <typename T>
+__device__ __forceinline__ T gelu(const T& x) {
   constexpr float kAlpha = M_SQRT1_2;
-  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
+  float f32_val = detail::to_f32(x);
+  return detail::from_f32<T>(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha))));
 }
 
-__device__ __forceinline__ float gelu_tanh(const float& val) {
-  const float cdf = 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
-  return val * cdf;
+// gelu_quick(x) = x * torch.sigmoid(1.702 * x)
+template <typename T>
+__device__ __forceinline__ T gelu_quick_act(const T& x) {
+  float f32_val = detail::to_f32(x);
+  return detail::from_f32<T>(f32_val / (1.0f + expf(-f32_val * 1.702f)));
 }
 
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
+template <typename T>
+__device__ __forceinline__ T gelu_tanh(const T& x) {
+  constexpr float kAlpha = 0.044715f;
+  constexpr float kBeta = 0.7978845608028654f;
+  float f32_val = detail::to_f32(x);
+  const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val))));
+  return detail::from_f32<T>(f32_val * cdf);
+}
+
+void silu_and_mul(at::Tensor& out, at::Tensor& input) {
   int d = input.size(-1) / 2;
   int64_t num_tokens = input.numel() / input.size(-1);
   dim3 grid(num_tokens);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
+#if USE_ROCM
+    sgl_hip::activation::act_and_mul_kernel<c_type, silu>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#else
     flashinfer::activation::act_and_mul_kernel<c_type, silu>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
-
+#endif
     return true;
   });
 }
 
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
   int d = input.size(-1) / 2;
   int64_t num_tokens = input.numel() / input.size(-1);
   dim3 grid(num_tokens);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
+#if USE_ROCM
+    sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#else
     flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#endif
     return true;
   });
 }
 
+void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
+  int d = input.size(-1) / 2;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  dim3 grid(num_tokens);
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
+    uint32_t vec_size = 16 / sizeof(c_type);
+    dim3 block(std::min(d / vec_size, 1024U));
+#if USE_ROCM
+    sgl_hip::activation::act_and_mul_kernel<c_type, gelu>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#else
+    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#endif
+    return true;
+  });
+}
+
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
-  int d = input.size(-1) / 2;
+#if USE_ROCM
+void gelu_quick(at::Tensor& out, const at::Tensor& input) {
+  int d = input.size(-1);
   int64_t num_tokens = input.numel() / input.size(-1);
   dim3 grid(num_tokens);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
-    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
+    sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
 
     return true;
   });
 }
+#endif
diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc
index 46a50ca6b..9010d0b26 100644
--- a/sgl-kernel/csrc/torch_extension_rocm.cc
+++ b/sgl-kernel/csrc/torch_extension_rocm.cc
@@ -19,6 +19,20 @@ limitations under the License.
 #include "sgl_kernel_ops.h"
 
 TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
+  /*
+   * From csrc/activation
+   */
+  m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
+  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
+
+  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
+  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
+
+  m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
+  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+
+  m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
+  m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
   /*
    * From csrc/allreduce
    */
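Context for the header introduced next: the fused `*_and_mul` kernels treat the last input dimension as two halves, `[gate | up]`; the gate half runs through the activation (evaluated in fp32, then cast back to the storage dtype) and scales the up half. A PyTorch sketch of those semantics, illustrative only and not part of this change:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the fused kernel: activation in fp32, cast back to the
    # storage dtype, then an elementwise product with the second half.
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    return F.silu(gate.float()).to(x.dtype) * up
```

The same layout applies to `gelu_and_mul` and `gelu_tanh_and_mul`; only the activation function changes.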
diff --git a/sgl-kernel/include/hip_act_and_mul.cuh b/sgl-kernel/include/hip_act_and_mul.cuh
new file mode 100644
index 000000000..ddb1b702d
--- /dev/null
+++ b/sgl-kernel/include/hip_act_and_mul.cuh
@@ -0,0 +1,87 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include "utils.h"
+
+#define kBitsToLoad 128
+#define kBytesToLoad (kBitsToLoad / 8)
+
+// Adapted from
+// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)
+
+namespace sgl_hip {
+namespace activation {
+
+template <typename T, T (*Activation)(const T&)>
+__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
+  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
+  const int64_t token_idx = blockIdx.x;
+  const int64_t thread_idx = threadIdx.x;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * 2 * d;
+
+#pragma unroll 1
+  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
+    sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
+    x_vec.cast_load(input + offset + idx * vec_size);
+    y_vec.cast_load(input + offset + d + idx * vec_size);
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
+    }
+    out_vec.cast_store(out + token_idx * d + idx * vec_size);
+  }
+
+  const int64_t remaining_offset = d - d % (stride * vec_size);
+  // process the remaining elements
+#pragma unroll 1
+  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
+    T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx];
+    out[token_idx * d + remaining_offset + idx] = Activation(x) * y;
+  }
+}
+
+template <typename T, T (*Activation)(const T&)>
+__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
+  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
+  const int64_t token_idx = blockIdx.x;
+  const int64_t thread_idx = threadIdx.x;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * d;
+
+#pragma unroll 1
+  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
+    sgl_hip::vec_t<T, vec_size> x_vec, out_vec;
+    x_vec.cast_load(input + offset + idx * vec_size);
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      out_vec[i] = Activation(x_vec[i]);
+    }
+    out_vec.cast_store(out + token_idx * d + idx * vec_size);
+  }
+
+  const int64_t remaining_offset = d - d % (stride * vec_size);
+  // process the remaining elements
+#pragma unroll 1
+  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
+    T x = input[offset + remaining_offset + idx];
+    out[token_idx * d + remaining_offset + idx] = Activation(x);
+  }
+}
+
+} // namespace activation
+} // namespace sgl_hip
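The `gelu_quick_act` device helper in activation.cu above computes `x / (1 + exp(-1.702 * x))` rather than the textbook `x * sigmoid(1.702 * x)`; the two are the same function, since `sigmoid(z) = 1 / (1 + exp(-z))`. A quick numerical check of that identity (illustrative only):

```python
import torch

x = torch.randn(4096)
a = x * torch.sigmoid(1.702 * x)       # QuickGELU as usually written
b = x / (1.0 + torch.exp(-1.702 * x))  # form used by the kernel
assert torch.allclose(a, b)            # identical up to rounding
```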
diff --git a/sgl-kernel/include/hip_math_def.h b/sgl-kernel/include/hip_math_def.h
new file mode 100644
index 000000000..21cc67456
--- /dev/null
+++ b/sgl-kernel/include/hip_math_def.h
@@ -0,0 +1,94 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#if defined(__HIP_PLATFORM_AMD__)
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
+
+// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
+
+namespace amdgpu {
+
+template <typename T>
+__forceinline__ __device__ T shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize);
+
+template <typename destDtype, typename srcDtype>
+__forceinline__ __device__ destDtype cast(srcDtype val);
+
+// specialization
+template <>
+__forceinline__ __device__ float shfl_xor_sync(unsigned mask, float var, int laneMask, int width) {
+  return __shfl_xor(var, laneMask, width);
+}
+
+template <>
+__forceinline__ __device__ int shfl_xor_sync(unsigned mask, int var, int laneMask, int width) {
+  return __shfl_xor(var, laneMask, width);
+}
+
+template <>
+__forceinline__ __device__ float cast<float, float>(float val) {
+  return val;
+}
+
+template <>
+__forceinline__ __device__ float cast<float, __half>(__half val) {
+  return __half2float(val);
+}
+
+template <>
+__forceinline__ __device__ float cast<float, __hip_bfloat16>(__hip_bfloat16 val) {
+  return __bfloat162float(val);
+}
+
+template <>
+__forceinline__ __device__ __half cast<__half, float>(float fval) {
+  return __float2half(fval);
+}
+
+template <>
+__forceinline__ __device__ __hip_bfloat16 cast<__hip_bfloat16, float>(float fval) {
+  return __float2bfloat16(fval);
+}
+
+} // namespace amdgpu
+
+template <typename T>
+__forceinline__ __device__ T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize) {
+  return amdgpu::shfl_xor_sync(mask, var, laneMask, width);
+}
+
+template <typename srcDtype>
+__device__ __forceinline__ float castToFloat(srcDtype val) {
+  return amdgpu::cast<float, srcDtype>(val);
+}
+
+template <typename dstDtype>
+__device__ __forceinline__ dstDtype castFromFloat(float val) {
+  return amdgpu::cast<dstDtype, float>(val);
+}
+
+// operator overload to support flashinfer
+__host__ __device__ __forceinline__ __half operator*(const __half& x, const __half& y) {
+  __half h_x = x;
+  __half h_y = y;
+  return __hmul(h_x, h_y);
+}
+
+#endif
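The `castToFloat`/`castFromFloat` shims above route half and bf16 math through fp32. One property this relies on: fp32 represents every fp16 and bf16 value exactly, so the round-trip itself is lossless; only the final rounding back to the narrow dtype loses precision. A host-side illustration:

```python
import torch

x = torch.randn(64, dtype=torch.bfloat16)
roundtrip = x.float().to(torch.bfloat16)  # castToFloat, then castFromFloat
assert torch.equal(x, roundtrip)          # no value changes on the way back
```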
diff --git a/sgl-kernel/include/hip_vec_dtypes.h b/sgl-kernel/include/hip_vec_dtypes.h
new file mode 100644
index 000000000..a68a6986e
--- /dev/null
+++ b/sgl-kernel/include/hip_vec_dtypes.h
@@ -0,0 +1,101 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#if USE_ROCM
+
+#include <hip/hip_runtime.h>
+#include <cstddef>
+#include <type_traits>
+
+// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
+
+#define SGL_HIP_INLINE inline __attribute__((always_inline)) __device__
+
+namespace sgl_hip {
+
+template <typename float_t, size_t vec_size>
+struct vec_t;
+
+template <typename srcDtype, typename dstDtype, size_t vec_size>
+SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src);
+
+template <typename srcDtype, typename dstDtype, size_t vec_size>
+SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src);
+
+template <typename float_t, size_t vec_size>
+struct vec_t {
+  SGL_HIP_INLINE float_t& operator[](size_t i);
+  SGL_HIP_INLINE const float_t& operator[](size_t i) const;
+  SGL_HIP_INLINE float_t* ptr();
+
+  SGL_HIP_INLINE void load(const float_t* ptr);
+  SGL_HIP_INLINE void store(float_t* ptr) const;
+
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src);
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr);
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const;
+};
+
+} // namespace sgl_hip
+
+// **** impl *****
+
+namespace sgl_hip {
+
+template <typename srcDtype, typename dstDtype, size_t vec_size>
+SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src_ptr) {
+  if constexpr (std::is_same<srcDtype, dstDtype>::value) {
+    dst.load(src_ptr);
+  } else {
+    vec_t<srcDtype, vec_size> tmp;
+    tmp.load(src_ptr);
+    dst.cast_from(tmp);
+  }
+}
+
+template <typename srcDtype, typename dstDtype, size_t vec_size>
+SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src) {
+  if constexpr (std::is_same<srcDtype, dstDtype>::value) {
+    src.store(dst_ptr);
+  } else {
+    vec_t<dstDtype, vec_size> tmp;
+    tmp.cast_from(src);
+    tmp.store(dst_ptr);
+  }
+}
+
+template <typename float_t, size_t vec_size>
+template <typename T>
+SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_load(const T* ptr) {
+  cast_load_impl(*this, ptr);
+}
+
+template <typename float_t, size_t vec_size>
+template <typename T>
+SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_store(T* ptr) const {
+  cast_store_impl(ptr, *this);
+}
+
+} // namespace sgl_hip
+
+#include "impl/hip_vec_bf16_impl.h"
+#include "impl/hip_vec_fp32_impl.h"
+#include "impl/hip_vec_half_impl.h"
+#endif
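Every `vec_t` specialization that follows moves 16 bytes per load/store (`kBitsToLoad = 128` in hip_act_and_mul.cuh), so the lane count is fixed by the element size; this is also why the Python wrapper rejects inputs whose last dimension is not a 16-byte multiple. A small sketch of the arithmetic:

```python
import torch

# vec_size = kBytesToLoad / sizeof(T)
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    print(dtype, "->", 16 // dtype.itemsize)  # 8, 8, and 4 lanes per vector
```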
diff --git a/sgl-kernel/include/impl/hip_vec_bf16_impl.h b/sgl-kernel/include/impl/hip_vec_bf16_impl.h
new file mode 100644
index 000000000..b783f3f43
--- /dev/null
+++ b/sgl-kernel/include/impl/hip_vec_bf16_impl.h
@@ -0,0 +1,177 @@
+#pragma once
+
+#if USE_ROCM
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_runtime.h>
+
+// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
+
+using nv_bfloat16 = __hip_bfloat16;
+using nv_bfloat162 = __hip_bfloat162;
+
+__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) {
+  __hip_bfloat162 t;
+  t.x = x;
+  t.y = y;
+  return t;
+}
+
+namespace sgl_hip {
+
+// nv_bfloat16 x 1
+template <>
+struct vec_t<nv_bfloat16, 1> {
+  nv_bfloat16 data;
+
+  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  SGL_HIP_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
+  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
+  data = *ptr;
+}
+
+SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
+  *ptr = data;
+}
+
+// nv_bfloat16 x 2
+template <>
+struct vec_t<nv_bfloat16, 2> {
+  nv_bfloat162 data;
+
+  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  SGL_HIP_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
+  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
+  data = *((nv_bfloat162*)ptr);
+}
+
+SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
+  *((nv_bfloat162*)ptr) = data;
+}
+
+template <>
+struct vec_t<nv_bfloat16, 4> {
+  uint2 data;
+
+  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  SGL_HIP_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
+  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
+  data = *((uint2*)ptr);
+}
+
+SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+// nv_bfloat16 x 8 or more
+
+template <size_t vec_size>
+struct vec_t<nv_bfloat16, vec_size> {
+  uint4 data[vec_size / 8];
+
+  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)data)[i];
+  }
+  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)data)[i];
+  }
+  SGL_HIP_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  SGL_HIP_INLINE void load(const nv_bfloat16* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+} // namespace sgl_hip
+
+#endif
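Note how the wider specializations above back the vector with raw machine words (`uint2`, `uint4`) rather than arrays of the element type, with `operator[]` reinterpreting the bytes. A loose host-side analogue of that reinterpretation:

```python
import torch

x = torch.randn(8, dtype=torch.bfloat16)    # 8 x 2 bytes = one 16-byte vector
words = x.view(torch.int32)                 # same storage seen as 4 x 32-bit words
assert words.numel() == 4
assert torch.equal(words.view(torch.bfloat16), x)  # reinterpretation is lossless
```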
diff --git a/sgl-kernel/include/impl/hip_vec_fp32_impl.h b/sgl-kernel/include/impl/hip_vec_fp32_impl.h
new file mode 100644
index 000000000..97cba6320
--- /dev/null
+++ b/sgl-kernel/include/impl/hip_vec_fp32_impl.h
@@ -0,0 +1,129 @@
+#pragma once
+
+#if USE_ROCM
+
+#include <hip/hip_runtime.h>
+
+// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
+
+namespace sgl_hip {
+
+template <>
+struct vec_t<float, 1> {
+  float data;
+
+  SGL_HIP_INLINE float& operator[](size_t i) {
+    return ((float*)(&data))[i];
+  }
+  SGL_HIP_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(&data))[i];
+  }
+  SGL_HIP_INLINE float* ptr() {
+    return reinterpret_cast<float*>(&data);
+  }
+  SGL_HIP_INLINE void load(const float* ptr);
+  SGL_HIP_INLINE void store(float* ptr) const;
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+SGL_HIP_INLINE void vec_t<float, 1>::load(const float* ptr) {
+  data = *ptr;
+}
+
+SGL_HIP_INLINE void vec_t<float, 1>::store(float* ptr) const {
+  *ptr = data;
+}
+
+// float x 2
+
+template <>
+struct vec_t<float, 2> {
+  float2 data;
+
+  SGL_HIP_INLINE float& operator[](size_t i) {
+    return ((float*)(&data))[i];
+  }
+  SGL_HIP_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(&data))[i];
+  }
+  SGL_HIP_INLINE float* ptr() {
+    return reinterpret_cast<float*>(&data);
+  }
+  SGL_HIP_INLINE void load(const float* ptr);
+  SGL_HIP_INLINE void store(float* ptr) const;
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+SGL_HIP_INLINE void vec_t<float, 2>::load(const float* ptr) {
+  data = *((float2*)ptr);
+}
+
+SGL_HIP_INLINE void vec_t<float, 2>::store(float* ptr) const {
+  *((float2*)ptr) = data;
+}
+
+// float x 4 or more
+template <size_t vec_size>
+struct vec_t<float, vec_size> {
+  float4 data[vec_size / 4];
+
+  SGL_HIP_INLINE float& operator[](size_t i) {
+    return ((float*)(data))[i];
+  }
+  SGL_HIP_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(data))[i];
+  }
+  SGL_HIP_INLINE float* ptr() {
+    return reinterpret_cast<float*>(&data);
+  }
+  SGL_HIP_INLINE void load(const float* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      data[i] = ((float4*)ptr)[i];
+    }
+  }
+  SGL_HIP_INLINE void store(float* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      ((float4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+} // namespace sgl_hip
+
+#endif
diff --git a/sgl-kernel/include/impl/hip_vec_half_impl.h b/sgl-kernel/include/impl/hip_vec_half_impl.h
new file mode 100644
index 000000000..767b9c62f
--- /dev/null
+++ b/sgl-kernel/include/impl/hip_vec_half_impl.h
@@ -0,0 +1,172 @@
+#pragma once
+
+#if USE_ROCM
+
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
+
+// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
+
+using half = __half;
+using half2 = __half2;
+
+namespace sgl_hip {
+
+// half x 1
+template <>
+struct vec_t<half, 1> {
+  half data;
+
+  SGL_HIP_INLINE half& operator[](size_t i) {
+    return ((half*)(&data))[i];
+  }
+  SGL_HIP_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  SGL_HIP_INLINE half* ptr() {
+    return reinterpret_cast<half*>(&data);
+  }
+  SGL_HIP_INLINE void load(const half* ptr);
+  SGL_HIP_INLINE void store(half* ptr) const;
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+SGL_HIP_INLINE void vec_t<half, 1>::load(const half* ptr) {
+  data = *ptr;
+}
+
+SGL_HIP_INLINE void vec_t<half, 1>::store(half* ptr) const {
+  *ptr = data;
+}
+
+// half x 2
+template <>
+struct vec_t<half, 2> {
+  half2 data;
+
+  SGL_HIP_INLINE half& operator[](size_t i) {
+    return ((half*)(&data))[i];
+  }
+  SGL_HIP_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  SGL_HIP_INLINE half* ptr() {
+    return reinterpret_cast<half*>(&data);
+  }
+  SGL_HIP_INLINE void load(const half* ptr);
+  SGL_HIP_INLINE void store(half* ptr) const;
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+SGL_HIP_INLINE void vec_t<half, 2>::load(const half* ptr) {
+  data = *((half2*)ptr);
+}
+
+SGL_HIP_INLINE void vec_t<half, 2>::store(half* ptr) const {
+  *((half2*)ptr) = data;
+}
+
+// half x 4
+
+template <>
+struct vec_t<half, 4> {
+  uint2 data;
+
+  SGL_HIP_INLINE half& operator[](size_t i) {
+    return ((half*)(&data))[i];
+  }
+  SGL_HIP_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  SGL_HIP_INLINE half* ptr() {
+    return reinterpret_cast<half*>(&data);
+  }
+  SGL_HIP_INLINE void load(const half* ptr);
+  SGL_HIP_INLINE void store(half* ptr) const;
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+SGL_HIP_INLINE void vec_t<half, 4>::load(const half* ptr) {
+  data = *((uint2*)ptr);
+}
+
+SGL_HIP_INLINE void vec_t<half, 4>::store(half* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+// half x 8 or more
+
+template <size_t vec_size>
+struct vec_t<half, vec_size> {
+  uint4 data[vec_size / 8];
+
+  SGL_HIP_INLINE half& operator[](size_t i) {
+    return ((half*)data)[i];
+  }
+  SGL_HIP_INLINE const half& operator[](size_t i) const {
+    return ((const half*)data)[i];
+  }
+  SGL_HIP_INLINE half* ptr() {
+    return reinterpret_cast<half*>(&data);
+  }
+  SGL_HIP_INLINE void load(const half* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  SGL_HIP_INLINE void store(half* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  SGL_HIP_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+};
+
+} // namespace sgl_hip
+#endif
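Since `gelu_quick` is registered only in the ROCm build (see `torch_extension_rocm.cc` earlier in this diff, and the `torch.version.hip` guard the change adds to the Python package), call sites should guard the import the same way. A minimal usage sketch under that assumption:

```python
import torch

if torch.version.hip is not None:  # ROCm build of sgl-kernel
    from sgl_kernel import gelu_quick

    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    y = gelu_quick(x)  # same shape as x; last dim must be a 16-byte multiple
```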
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index ffd240a04..ca8276050 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -138,9 +138,10 @@ void sgl_fused_add_rmsnorm(
     torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void silu_and_mul(at::Tensor& out, at::Tensor& input);
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input);
+void gelu_and_mul(at::Tensor& out, at::Tensor& input);
+
 void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor q,
     at::Tensor k,
@@ -151,6 +152,9 @@ void apply_rope_pos_ids_cos_sin_cache(
     bool interleave,
     int64_t cuda_stream);
 
+#ifdef USE_ROCM
+void gelu_quick(at::Tensor& out, const at::Tensor& input);
+#endif
 /*
  * From csrc/gemm
  */
diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h
index 1054dbc52..d7d0d5d1f 100644
--- a/sgl-kernel/include/utils.h
+++ b/sgl-kernel/include/utils.h
@@ -19,7 +19,20 @@ limitations under the License.
 #include
 #include
 
-#include
+#ifdef USE_ROCM
+// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
+#define _DISPATCH_CASE_F16(c_type, ...) \
+  case at::ScalarType::Half: {          \
+    using c_type = __half;              \
+    return __VA_ARGS__();               \
+  }
+
+#define _DISPATCH_CASE_BF16(c_type, ...) \
+  case at::ScalarType::BFloat16: {       \
+    using c_type = __hip_bfloat16;       \
+    return __VA_ARGS__();                \
+  }
+#endif  // USE_ROCM
 
 #ifndef USE_ROCM
 // Adapt from FlashInfer
@@ -31,7 +44,7 @@ limitations under the License.
   }
 #else
 #define _DISPATCH_CASE_F16(c_type, ...)
-#endif
+#endif  // FLASHINFER_ENABLE_F16
 
 #ifdef FLASHINFER_ENABLE_BF16
 #define _DISPATCH_CASE_BF16(c_type, ...) \
@@ -41,7 +54,7 @@ limitations under the License.
   }
 #else
 #define _DISPATCH_CASE_BF16(c_type, ...)
-#endif
+#endif  // FLASHINFER_ENABLE_BF16
 
 #ifdef FLASHINFER_ENABLE_FP8_E4M3
 #define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
@@ -51,7 +64,7 @@ limitations under the License.
   }
 #else
 #define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
-#endif
+#endif  // FLASHINFER_ENABLE_FP8_E4M3
 
 #ifdef FLASHINFER_ENABLE_FP8_E5M2
 #define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
@@ -61,7 +74,7 @@ limitations under the License.
   }
 #else
 #define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
-#endif
+#endif  // FLASHINFER_ENABLE_FP8_E5M2
 
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
   [&]() -> bool {                                                        \
     switch (pytorch_dtype) {                                             \
@@ -197,7 +210,7 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
 inline bool is_float8_tensor(const at::Tensor& tensor) {
   return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;
 }
-#endif
+#endif  // USE_ROCM
 
 struct cuda_error : public std::runtime_error {
   /**
@@ -267,7 +280,6 @@ inline bool getEnvEnablePDL() {
 #define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
 #endif
 
-#ifndef USE_ROCM
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
   [&]() -> bool {                                                              \
     switch (pytorch_dtype) {                                                   \
@@ -284,7 +296,6 @@
       return false;                                                            \
   }                                                                            \
   }()
-#endif
 
 #define DISPATCH_CASE_INTEGRAL_TYPES(...)              \
   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
@@ -297,52 +308,99 @@ inline bool getEnvEnablePDL() {
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
 
 #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
+
+#ifndef USE_ROCM
 #define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize  // 64
+#endif
+
+#if defined(__HIP_PLATFORM_AMD__)
+
+#include "hip_math_def.h"
+#include "hip_vec_dtypes.h"
+
+#else
+
+template <typename srcDtype>
+__device__ __forceinline__ float castToFloat(srcDtype val) {
+  return static_cast<float>(val);
+}
+
+template <typename dstDtype>
+__device__ __forceinline__ dstDtype castFromFloat(float val) {
+  return static_cast<dstDtype>(val);
+}
+
+#endif
+
+// add FP8 support
 #ifndef USE_ROCM
 #include <c10/util/Float8_e4m3fn.h>
 
 using FP8_TYPE = c10::Float8_e4m3fn;
 C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else
-#include <c10/util/Float8_e4m3fnuz.h>
+#else  // USE_ROCM
+
+#if HIP_FP8_TYPE_FNUZ
+#include <c10/util/Float8_e4m3fnuz.h>
 
 using FP8_TYPE = c10::Float8_e4m3fnuz;
 constexpr auto FP8_E4M3_MAX = 224.0f;
-#endif
+#else
+#if HIP_FP8_TYPE_E4M3
+#include <c10/util/Float8_e4m3fn.h>
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+#else
+#error "fp8 is not supported in this processor (arch < gfx942)."
+#endif  // HIP_FP8_TYPE_E4M3
+#endif  // HIP_FP8_TYPE_FNUZ
+#endif  // USE_ROCM
+
+#define FULL_MASK 0xffffffff
 
-#ifndef USE_ROCM
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
+#ifndef USE_ROCM
   float old;
   old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
                      : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
   return old;
+#else
+  int* addr_as_i = (int*)addr;
+  int old = *addr_as_i, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed))));
+  } while (assumed != old);
+  return __int_as_float(old);
+#endif
 }
 
-__device__ __forceinline__ float warpReduceMax(float max_value) {
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
-  return max_value;
+__device__ __forceinline__ float warpReduceMax(float value) {
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1));
+  return value;
 }
 
-__device__ __forceinline__ float blockReduceMax(float max_value) {
+__device__ __forceinline__ float blockReduceMax(float value) {
   static __shared__ float warpLevelMaxs[WARP_SIZE];
   const int laneId = threadIdx.x % WARP_SIZE;
   const int warpId = threadIdx.x / WARP_SIZE;
 
-  max_value = warpReduceMax(max_value);
+  value = warpReduceMax(value);
 
-  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
+  if (laneId == 0) warpLevelMaxs[warpId] = value;
   __syncthreads();
 
-  max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
-  if (warpId == 0) max_value = warpReduceMax(max_value);
+  value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
+  if (warpId == 0) value = warpReduceMax(value);
 
-  return max_value;
+  return value;
 }
-#endif
 
 // Pads to a multiple of `alignment` rows.
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) { diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 5cecfc3c0..2a4656aea 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -31,6 +31,10 @@ from sgl_kernel.elementwise import ( silu_and_mul, ) from sgl_kernel.fused_moe import fused_marlin_moe + +if torch.version.hip is not None: + from sgl_kernel.elementwise import gelu_quick + from sgl_kernel.gemm import ( awq_dequantize, bmm_fp8, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 0e2bbc990..01ee71860 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -179,7 +179,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.silu_and_mul.default(out, input) return out @@ -194,7 +194,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input) return out @@ -209,10 +209,34 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.gelu_and_mul.default(out, input) return out +if torch.version.hip is not None: + + def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + """ + Quick-GELU: y = x * sigmoid(1.702 * x) + + The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores, + so the last-dimension byte length must be a multiple of 16 bytes. + """ + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError( + f"The last dimension ({input.shape[-1]}) x itemsize " + f"({input.dtype.itemsize}) must be a multiple of 16 bytes." + ) + + if out is not None: + assert input.shape == out.shape, f"{input.shape} != {out.shape}" + else: + out = torch.empty_like(input) + + torch.ops.sgl_kernel.gelu_quick(out, input) + return out + + def apply_rope_with_cos_sin_cache_inplace( positions: torch.Tensor, query: torch.Tensor, diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index a814b8196..47f59071f 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -36,16 +36,18 @@ def _get_version(): operator_namespace = "sgl_kernel" include_dirs = [ root / "include", + root / "include" / "impl", root / "csrc", ] sources = [ "csrc/allreduce/custom_all_reduce.hip", "csrc/allreduce/quick_all_reduce.cu", + "csrc/elementwise/activation.cu", "csrc/moe/moe_align_kernel.cu", "csrc/moe/moe_topk_softmax_kernels.cu", - "csrc/torch_extension_rocm.cc", "csrc/speculative/eagle_utils.cu", + "csrc/torch_extension_rocm.cc", ] cxx_flags = ["-O3"] @@ -69,6 +71,7 @@ if amdgpu_target not in ["gfx942", "gfx950"]: ) sys.exit(1) + hipcc_flags = [ "-DNDEBUG", f"-DOPERATOR_NAMESPACE={operator_namespace}",