adapt to sglang v0.5.2rc1 on dcu

This commit is contained in:
maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

View File

@@ -0,0 +1,87 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include "utils.h"
#define kBitsToLoad 128
#define kBytesToLoad (kBitsToLoad / 8)
// Adapted from
// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)
namespace sgl_hip {
namespace activation {
template <typename T, T (*Activation)(const T&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * 2 * d;
#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
x_vec.cast_load(input + offset + idx * vec_size);
y_vec.cast_load(input + offset + d + idx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
out_vec[i] = Activation(x_vec[i]) * y_vec[i];
}
out_vec.cast_store(out + token_idx * d + idx * vec_size);
}
const int64_t remaining_offset = d - d % (stride * vec_size);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx];
out[token_idx * d + remaining_offset + idx] = Activation(x) * y;
}
}
template <typename T, T (*Activation)(const T&)>
__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * d;
#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
x_vec.cast_load(input + offset + idx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
out_vec[i] = Activation(x_vec[i]);
}
out_vec.cast_store(out + token_idx * d + idx * vec_size);
}
const int64_t remaining_offset = d - d % (stride * vec_size);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
T x = input[offset + remaining_offset + idx];
out[token_idx * d + remaining_offset + idx] = Activation(x);
}
}
} // namespace activation
} // namespace sgl_hip

View File

@@ -0,0 +1,94 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_common.h>
#include <hip/hip_fp16.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
namespace amdgpu {
template <typename T>
__forceinline__ __device__ T shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize);
template <typename srcDtype, typename destDtype>
__forceinline__ __device__ destDtype cast(srcDtype val);
// specialization
template <>
__forceinline__ __device__ float shfl_xor_sync(unsigned mask, float var, int laneMask, int width) {
return __shfl_xor(var, laneMask, width);
}
template <>
__forceinline__ __device__ int shfl_xor_sync(unsigned mask, int var, int laneMask, int width) {
return __shfl_xor(var, laneMask, width);
}
template <>
__forceinline__ __device__ float cast(float val) {
return val;
}
template <>
__forceinline__ __device__ float cast(__half val) {
return __half2float(val);
}
template <>
__forceinline__ __device__ float cast(__hip_bfloat16 val) {
return __bfloat162float(val);
}
template <>
__forceinline__ __device__ __half cast(float fval) {
return __float2half(fval);
}
template <>
__forceinline__ __device__ __hip_bfloat16 cast(float fval) {
return __float2bfloat16(fval);
}
} // namespace amdgpu
template <typename T>
__forceinline__ __device__ T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize) {
return amdgpu::shfl_xor_sync(mask, var, laneMask, width);
}
template <typename srcDtype>
__device__ __forceinline__ float castToFloat(srcDtype val) {
return amdgpu::cast<srcDtype, float>(val);
}
template <typename dstDtype>
__device__ __forceinline__ dstDtype castFromFloat(float val) {
return amdgpu::cast<float, dstDtype>(val);
}
// operator overload to support flashinfer
__host__ __device__ __forceinline__ __half operator*(const __half& x, const __half& y) {
__half h_x = x;
__half h_y = y;
return __hmul(h_x, h_y);
}
#endif

View File

@@ -0,0 +1,101 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#if USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_common.h>
#include <hip/hip_fp16.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)d
#define SGL_HIP_INLINE inline __attribute__((always_inline)) __device__
namespace sgl_hip {
template <typename float_t, size_t vec_size>
struct vec_t;
template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src);
template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src);
template <typename float_t, size_t vec_size>
struct vec_t {
SGL_HIP_INLINE float_t& operator[](size_t i);
SGL_HIP_INLINE const float_t& operator[](size_t i) const;
SGL_HIP_INLINE float_t* ptr();
SGL_HIP_INLINE void load(const float_t* ptr);
SGL_HIP_INLINE void store(float_t* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src);
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr);
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const;
};
} // namespace sgl_hip
// **** impl *****
namespace sgl_hip {
template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src_ptr) {
if constexpr (std::is_same<srcDtype, dstDtype>::value) {
dst.load(src_ptr);
} else {
vec_t<srcDtype, vec_size> tmp;
tmp.load(src_ptr);
dst.cast_from(tmp);
}
}
template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src) {
if constexpr (std::is_same<srcDtype, dstDtype>::value) {
src.store(dst_ptr);
} else {
vec_t<dstDtype, vec_size> tmp;
tmp.cast_from(src);
tmp.store(dst_ptr);
}
}
template <typename float_t, size_t vec_size>
template <typename T>
SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename float_t, size_t vec_size>
template <typename T>
SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
} // namespace sgl_hip
#include "impl/hip_vec_bf16_impl.h"
#include "impl/hip_vec_fp32_impl.h"
#include "impl/hip_vec_half_impl.h"
#endif

View File

@@ -0,0 +1,177 @@
#pragma once
#if USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_common.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
using nv_bfloat16 = __hip_bfloat16;
using nv_bfloat162 = __hip_bfloat162;
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) {
__hip_bfloat162 t;
t.x = x;
t.y = y;
return t;
}
namespace sgl_hip {
// nv_bfloat16 x 1
template <>
struct vec_t<nv_bfloat16, 1> {
nv_bfloat16 data;
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
return ((nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
return ((const nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE nv_bfloat16* ptr() {
return reinterpret_cast<nv_bfloat16*>(&data);
}
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
data = *ptr;
}
SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
*ptr = data;
}
// nv_bfloat16 x 2
template <>
struct vec_t<nv_bfloat16, 2> {
nv_bfloat162 data;
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
return ((nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
return ((const nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE nv_bfloat16* ptr() {
return reinterpret_cast<nv_bfloat16*>(&data);
}
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
data = *((nv_bfloat162*)ptr);
}
SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
*((nv_bfloat162*)ptr) = data;
}
template <>
struct vec_t<nv_bfloat16, 4> {
uint2 data;
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
return ((nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
return ((const nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE nv_bfloat16* ptr() {
return reinterpret_cast<nv_bfloat16*>(&data);
}
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
data = *((uint2*)ptr);
}
SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
*((uint2*)ptr) = data;
}
// nv_bfloat16 x 8 or more
template <size_t vec_size>
struct vec_t<nv_bfloat16, vec_size> {
uint4 data[vec_size / 8];
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
return ((nv_bfloat16*)data)[i];
}
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
return ((const nv_bfloat16*)data)[i];
}
SGL_HIP_INLINE nv_bfloat16* ptr() {
return reinterpret_cast<nv_bfloat16*>(&data);
}
SGL_HIP_INLINE void load(const nv_bfloat16* ptr) {
#pragma unoll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4*)ptr)[i];
}
}
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const {
#pragma unoll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4*)ptr)[i] = data[i];
}
}
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
} // namespace sgl_hip
#endif

View File

@@ -0,0 +1,129 @@
#pragma once
#if USE_ROCM
#include <hip/hip_common.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
namespace sgl_hip {
template <>
struct vec_t<float, 1> {
float data;
SGL_HIP_INLINE float& operator[](size_t i) {
return ((float*)(&data))[i];
}
SGL_HIP_INLINE const float& operator[](size_t i) const {
return ((const float*)(&data))[i];
}
SGL_HIP_INLINE float* ptr() {
return reinterpret_cast<float*>(&data);
}
SGL_HIP_INLINE void load(const float* ptr);
SGL_HIP_INLINE void store(float* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<float, 1>::load(const float* ptr) {
data = *ptr;
}
SGL_HIP_INLINE void vec_t<float, 1>::store(float* ptr) const {
*ptr = data;
}
// float x 2
template <>
struct vec_t<float, 2> {
float2 data;
SGL_HIP_INLINE float& operator[](size_t i) {
return ((float*)(&data))[i];
}
SGL_HIP_INLINE const float& operator[](size_t i) const {
return ((const float*)(&data))[i];
}
SGL_HIP_INLINE float* ptr() {
return reinterpret_cast<float*>(&data);
}
SGL_HIP_INLINE void load(const float* ptr);
SGL_HIP_INLINE void store(float* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<float, 2>::load(const float* ptr) {
data = *((float2*)ptr);
}
SGL_HIP_INLINE void vec_t<float, 2>::store(float* ptr) const {
*((float2*)ptr) = data;
}
// float x 4 or more
template <size_t vec_size>
struct vec_t<float, vec_size> {
float4 data[vec_size / 4];
SGL_HIP_INLINE float& operator[](size_t i) {
return ((float*)(data))[i];
}
SGL_HIP_INLINE const float& operator[](size_t i) const {
return ((const float*)(data))[i];
}
SGL_HIP_INLINE float* ptr() {
return reinterpret_cast<float*>(&data);
}
SGL_HIP_INLINE void load(const float* ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = ((float4*)ptr)[i];
}
}
SGL_HIP_INLINE void store(float* ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4*)ptr)[i] = data[i];
}
}
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
} // namespace sgl_hip
#endif

View File

@@ -0,0 +1,172 @@
#pragma once
#if USE_ROCM
#include <hip/hip_common.h>
#include <hip/hip_fp16.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
using half = __half;
using half2 = __half2;
namespace sgl_hip {
// half x 1
template <>
struct vec_t<half, 1> {
half data;
SGL_HIP_INLINE half& operator[](size_t i) {
return ((half*)(&data))[i];
}
SGL_HIP_INLINE const half& operator[](size_t i) const {
return ((const half*)(&data))[i];
}
SGL_HIP_INLINE half* ptr() {
return reinterpret_cast<half*>(&data);
}
SGL_HIP_INLINE void load(const half* ptr);
SGL_HIP_INLINE void store(half* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<half, 1>::load(const half* ptr) {
data = *ptr;
}
SGL_HIP_INLINE void vec_t<half, 1>::store(half* ptr) const {
*ptr = data;
}
// half x 2
template <>
struct vec_t<half, 2> {
half2 data;
SGL_HIP_INLINE half& operator[](size_t i) {
return ((half*)(&data))[i];
}
SGL_HIP_INLINE const half& operator[](size_t i) const {
return ((const half*)(&data))[i];
}
SGL_HIP_INLINE half* ptr() {
return reinterpret_cast<half*>(&data);
}
SGL_HIP_INLINE void load(const half* ptr);
SGL_HIP_INLINE void store(half* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<half, 2>::load(const half* ptr) {
data = *((half2*)ptr);
}
SGL_HIP_INLINE void vec_t<half, 2>::store(half* ptr) const {
*((half2*)ptr) = data;
}
// half x 4
template <>
struct vec_t<half, 4> {
uint2 data;
SGL_HIP_INLINE half& operator[](size_t i) {
return ((half*)(&data))[i];
}
SGL_HIP_INLINE const half& operator[](size_t i) const {
return ((const half*)(&data))[i];
}
SGL_HIP_INLINE half* ptr() {
return reinterpret_cast<half*>(&data);
}
SGL_HIP_INLINE void load(const half* ptr);
SGL_HIP_INLINE void store(half* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<half, 4>::load(const half* ptr) {
data = *((uint2*)ptr);
}
SGL_HIP_INLINE void vec_t<half, 4>::store(half* ptr) const {
*((uint2*)ptr) = data;
}
// half x 8 or more
template <size_t vec_size>
struct vec_t<half, vec_size> {
uint4 data[vec_size / 8];
SGL_HIP_INLINE half& operator[](size_t i) {
return ((half*)data)[i];
}
SGL_HIP_INLINE const half& operator[](size_t i) const {
return ((const half*)data)[i];
}
SGL_HIP_INLINE half* ptr() {
return reinterpret_cast<half*>(&data);
}
SGL_HIP_INLINE void load(const half* ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4*)ptr)[i];
}
}
SGL_HIP_INLINE void store(half* ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4*)ptr)[i] = data[i];
}
}
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
} // namespace sgl_hip
#endif

View File

@@ -0,0 +1,20 @@
#include <torch/library.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CUDA(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)

View File

@@ -0,0 +1,330 @@
#pragma once
// For TORCH_CHECK
#include <torch/library.h>
namespace sglang {
//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The type definitions on the Python side can be found in: vllm/scalar_type.py
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
public:
enum NanRepr : uint8_t {
NAN_NONE = 0, // nans are not supported
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
NAN_REPR_ID_MAX
};
constexpr ScalarType(
uint8_t exponent,
uint8_t mantissa,
bool signed_,
int32_t bias,
bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
signed_(signed_),
bias(bias),
finite_values_only(finite_values_only),
nan_repr(nan_repr) {};
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits - 1, true, bias);
}
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits, false, bias);
}
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) {
TORCH_CHECK(mantissa > 0 && exponent > 0);
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) {
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
TORCH_CHECK(mantissa > 0 && exponent > 0);
TORCH_CHECK(
nan_repr != NAN_IEEE_754,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions");
return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr);
}
uint8_t const exponent; // size of the exponent field (0 for integer types)
uint8_t const mantissa; // size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
bool const signed_; // flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t const bias; // stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool const finite_values_only; // i.e. no +/-inf if true
NanRepr const nan_repr; // how NaNs are represented
// (not applicable for integer types)
using Id = int64_t;
private:
// Field size in id
template <typename T_>
static constexpr size_t member_id_field_width() {
using T = std::decay_t<T_>;
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
}
template <typename Fn, typename Init, typename Member, typename... Rest>
static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) {
auto new_val = f(val, member);
if constexpr (sizeof...(rest) > 0) {
return reduce_members_helper(f, new_val, rest...);
} else {
return new_val;
};
}
template <typename Fn, typename Init>
constexpr auto reduce_members(Fn f, Init init) const {
// Should be in constructor order for `from_id`
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr);
};
template <typename Fn, typename Init>
static constexpr auto reduce_member_types(Fn f, Init init) {
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
return dummy_type.reduce_members(f, init);
};
static constexpr auto id_size_bits() {
return reduce_member_types(
[](int acc, auto member) -> int { return acc + member_id_field_width<decltype(member)>(); }, 0);
}
public:
// unique id for this scalar type that can be computed at compile time for
// c++17 template specialization this is not needed once we migrate to
// c++20 and can pass literal classes as template parameters
constexpr Id id() const {
static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored");
auto or_and_advance = [](std::pair<Id, uint32_t> result, auto member) -> std::pair<Id, uint32_t> {
auto [id, bit_offset] = result;
auto constexpr bits = member_id_field_width<decltype(member)>();
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits};
};
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
}
// create a ScalarType from an id, for c++17 template specialization,
// this is not needed once we migrate to c++20 and can pass literal
// classes as template parameters
static constexpr ScalarType from_id(Id id) {
auto extract_and_advance = [id](auto result, auto member) {
using T = decltype(member);
auto [tuple, bit_offset] = result;
auto constexpr bits = member_id_field_width<T>();
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1));
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
};
auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair<std::tuple<>, int>{});
return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args);
}
constexpr int64_t size_bits() const {
return mantissa + exponent + is_signed();
}
constexpr bool is_signed() const {
return signed_;
}
constexpr bool is_integer() const {
return exponent == 0;
}
constexpr bool is_floating_point() const {
return exponent > 0;
}
constexpr bool is_ieee_754() const {
return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754;
}
constexpr bool has_nans() const {
return is_floating_point() && nan_repr != NAN_NONE;
}
constexpr bool has_infs() const {
return is_floating_point() && finite_values_only == false;
}
constexpr bool has_bias() const {
return bias != 0;
}
private:
double _floating_point_max() const {
TORCH_CHECK(mantissa <= 52 && exponent <= 11, "Cannot represent max/min as a double for type ", str());
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
max_mantissa -= 1;
}
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
TORCH_CHECK(exponent < 11, "Cannot represent max/min as a double for type ", str());
max_exponent += 1;
}
// adjust the exponent to match that of a double
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
// is the exponent bits), there is some precedent for non-standard biases,
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
// but to avoid premature over complication we are just assuming the
// standard exponent bias until there is a need to support non-standard
// biases
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double;
// shift the mantissa into the position for a double and
// the exponent
uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
return *reinterpret_cast<double*>(&double_raw);
}
constexpr std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), "Cannot represent max as a int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}
constexpr std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
} else {
TORCH_CHECK(!is_signed() || size_bits() <= 64, "Cannot represent min as a int64_t");
if (is_signed()) {
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
// then perform an arithmetic shift right to set all the bits above
// (size_bits() - 1) to 1
return {INT64_MIN >> (64 - size_bits())};
} else {
return {int64_t(0)};
}
}
}
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> max() const {
return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_max());
}
// Min representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> min() const {
return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_min());
}
std::string str() const {
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading f) the scheme is:
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
* flags:
* - no-flags: means it follows IEEE 754 conventions
* - f: means finite values only (no infinities)
* - n: means nans are supported (non-standard encoding)
* for integer types the scheme is:
* `[u]int<size_bits>[b<bias>]`
* - if bias is not present it means its zero
*/
if (is_floating_point()) {
auto ret =
"float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa);
if (!is_ieee_754()) {
if (finite_values_only) {
ret += "f";
}
if (nan_repr != NAN_NONE) {
ret += "n";
}
}
return ret;
} else {
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
if (has_bias()) {
ret += "b" + std::to_string(bias);
}
return ret;
}
}
constexpr bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ &&
finite_values_only == other.finite_values_only && nan_repr == other.nan_repr;
}
};
using ScalarTypeId = ScalarType::Id;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;
static inline constexpr auto kFloat16Id = kFloat16.id();
}; // namespace sglang

View File

@@ -0,0 +1,86 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/library.h>
#include <torch/torch.h>
#include <vector>
#include "sgl_kernel_torch_shim.h"
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
/*
* From flash-attention
*/
std::vector<at::Tensor> mha_fwd(
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
std::optional<const at::Tensor>&
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& seqlens_rotary_, // b
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
float const softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin,
std::optional<const at::Tensor>& sinks_);

View File

@@ -0,0 +1,758 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/all.h>
#include <torch/library.h>
#include <torch/torch.h>
#include <tuple>
#include <vector>
#include "scalar_type.hpp"
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
using fptr_t = int64_t;
/*
* From csrc/allreduce
*/
#ifdef USE_ROCM
// ROCM custom allreduce
fptr_t init_custom_ar(
torch::Tensor& meta,
torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets,
int64_t rank,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(
fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
// quick allreduce
fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size = std::nullopt);
void qr_destroy(fptr_t _fa);
torch::Tensor qr_get_handle(fptr_t _fa);
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#else
// custom allreduce
fptr_t
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void dispose(fptr_t _fa);
int64_t meta_size();
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
// mscclpp
torch::Tensor mscclpp_generate_unique_id();
fptr_t mscclpp_init_context(
const torch::Tensor& unique_id,
const int64_t rank,
const int64_t world_size,
torch::Tensor& scratch,
torch::Tensor& put_buffer,
const int64_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib,
const int64_t context_selection);
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks);
#endif
/*
* From csrc/attention
*/
void lightning_attention_decode(
const torch::Tensor& q,
const torch::Tensor& k,
const torch::Tensor& v,
const torch::Tensor& past_kv,
const torch::Tensor& slope,
torch::Tensor output,
torch::Tensor new_kv);
void merge_state(
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void merge_state_v2(
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
int64_t cutlass_mla_get_workspace_size(
int64_t max_seq_len,
int64_t num_batches,
int64_t sm_count = 0,
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
/*
* From csrc/elementwise
*/
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void sgl_fused_add_rmsnorm(
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
void silu_and_mul(at::Tensor& out, at::Tensor& input);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input);
void gelu_and_mul(at::Tensor& out, at::Tensor& input);
void apply_rope_pos_ids_cos_sin_cache(
at::Tensor q,
at::Tensor k,
at::Tensor q_rope,
at::Tensor k_rope,
at::Tensor cos_sin_cache,
at::Tensor pos_ids,
bool interleave,
bool enable_pdl,
int64_t cuda_stream,
const std::optional<at::Tensor>& v,
const std::optional<at::Tensor>& k_buffer,
const std::optional<at::Tensor>& v_buffer,
const std::optional<at::Tensor>& kv_cache_loc);
void downcast_fp8(
at::Tensor& k,
at::Tensor& v,
at::Tensor& k_out,
at::Tensor& v_out,
at::Tensor& k_scale,
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
int64_t cuda_stream);
#ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
/*
* From csrc/gemm
*/
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
void cutlass_scaled_fp4_mm(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
torch::Tensor int8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_blockwise_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype);
void scaled_fp4_quant(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
void sgl_per_token_group_quant_fp8(
at::Tensor input,
at::Tensor output_q,
at::Tensor output_s,
int64_t group_size,
double eps,
double fp8_min,
double fp8_max,
bool scale_ue8m0);
void sgl_per_token_group_quant_int8(
at::Tensor input,
at::Tensor output_q,
at::Tensor output_s,
int64_t group_size,
double eps,
double int8_min,
double int8_max);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
void bmm_fp8(
at::Tensor A,
at::Tensor B,
at::Tensor D,
at::Tensor A_scale,
at::Tensor B_scale,
at::Tensor workspace_buffer,
int64_t cublas_handle,
int64_t cuda_stream);
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
torch::Tensor gptq_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
torch::Tensor gptq_gemm(
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_shuffle,
int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
/*
* From csrc/moe
*/
void moe_align_block_size(
torch::Tensor topk_ids,
int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
torch::Tensor cumsum_buffer,
bool pad_sorted_token_ids);
void topk_softmax(
torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize);
std::vector<at::Tensor> moe_fused_gate(
at::Tensor& input,
at::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output);
void fp8_blockwise_scaled_grouped_mm(
torch::Tensor& output,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace);
void prepare_moe_input(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const std::optional<torch::Tensor>& blockscale_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void ep_moe_pre_reorder(
torch::Tensor input,
torch::Tensor gateup_input,
torch::Tensor src2dst,
torch::Tensor topk_ids,
torch::Tensor a1_scales,
int64_t start_expert_id,
int64_t end_expert_id,
int64_t topk,
bool use_per_token_if_dynamic);
void ep_moe_silu_and_mul(
torch::Tensor gateup_output,
torch::Tensor down_input,
torch::Tensor reorder_topk_ids,
torch::Tensor scales,
int64_t start_expert_id,
int64_t end_expert_id);
void ep_moe_post_reorder(
torch::Tensor down_output,
torch::Tensor output,
torch::Tensor src2dst,
torch::Tensor topk_ids,
torch::Tensor topk_weights,
int64_t start_expert_id,
int64_t end_expert_id,
int64_t topk);
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);
void apply_shuffle_mul_sum(
const torch::Tensor& input,
torch::Tensor& output,
const torch::Tensor& permutation,
const std::optional<torch::Tensor>& factors);
void cutlass_fp4_group_mm(
torch::Tensor& output,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_blockscale,
const torch::Tensor& b_blockscales,
const torch::Tensor& alphas,
const torch::Tensor& ab_strides,
const torch::Tensor& c_strides,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& sf_offsets);
void scaled_fp4_experts_quant(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& mask,
bool use_silu_and_mul);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
/*
* From csrc/speculative
*/
void tree_speculative_sampling_target_only(
at::Tensor predicts, // mutable
at::Tensor accept_index, // mutable
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor uniform_samples,
at::Tensor uniform_samples_for_final_sampling,
at::Tensor target_probs,
at::Tensor draft_probs,
double threshold_single = 1,
double threshold_acc = 1,
bool deterministic = true,
int64_t cuda_stream = 0);
void verify_tree_greedy(
at::Tensor predicts, // mutable
at::Tensor accept_index, // mutable
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor target_predict,
int64_t cuda_stream = 0);
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num,
int64_t tree_mask_mode);
void segment_packbits(
at::Tensor x,
at::Tensor input_indptr,
at::Tensor output_indptr,
at::Tensor y,
int64_t batch_size,
int64_t cuda_stream = 0);
/*
* From csrc/kvcacheio
*/
void transfer_kv_per_layer(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_pf_lf(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t layer_id,
int64_t item_size,
int64_t src_layout_dim,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer(
const at::Tensor src_k_layers,
const at::Tensor dst_k_layers,
const at::Tensor src_v_layers,
const at::Tensor dst_v_layers,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t num_layers,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_lf_pf(
const at::Tensor src_k_layers,
at::Tensor dst_k,
const at::Tensor src_v_layers,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t dst_layout_dim,
int64_t num_layers,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_mla(
const at::Tensor src,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_mla_pf_lf(
const at::Tensor src,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t layer_id,
int64_t item_size,
int64_t src_layout_dim,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_mla(
const at::Tensor src_layers,
const at::Tensor dst_layers,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t num_layers,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_mla_lf_pf(
const at::Tensor src_layers,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t dst_layout_dim,
int64_t num_layers,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_direct(
const std::vector<at::Tensor>& src_layers,
std::vector<at::Tensor> dst_layers,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size);
/*
* From FlashInfer
*/
void min_p_sampling_from_probs(
at::Tensor probs,
at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_min_p_arr,
double min_p_val,
bool deterministic,
std::optional<at::Generator> gen);
void top_k_renorm_probs(
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
void top_p_renorm_probs(
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
void top_k_top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_k_arr,
double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic,
std::optional<at::Generator> gen);
void top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic,
std::optional<at::Generator> gen);
void top_k_mask_logits(
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
namespace flash {
/*
* From fa2 sparse
*/
std::vector<at::Tensor> mha_fwd_sparse(
at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& block_count,
const at::Tensor& block_offset,
const at::Tensor& column_count,
const at::Tensor& column_index,
const std::optional<at::Tensor>& out_, // batch_size x seqlen_q x num_heads x head_size
const std::optional<at::Tensor>& alibi_slopes_, // num_heads or batch_size x num_heads
const double p_dropout,
const double softmax_scale,
bool is_causal,
const double softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);
std::vector<at::Tensor> mha_varlen_fwd_sparse(
at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
const at::Tensor& block_count,
const at::Tensor& block_offset,
const at::Tensor& column_count,
const at::Tensor& column_index,
const c10::optional<at::Tensor>& out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor& cu_seqlens_q, // b+1
const at::Tensor& cu_seqlens_k, // b+1
const c10::optional<at::Tensor>&
seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const c10::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
int64_t max_seqlen_q,
const int64_t max_seqlen_k,
const double p_dropout,
const double softmax_scale,
const bool zero_tensors,
bool is_causal,
const double softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
} // namespace flash
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size,
int64_t block_size_M,
int64_t block_size_N,
bool causal);
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count,
int64_t context_size,
int64_t block_size_M,
int64_t block_size_N,
bool causal);
/*
* From XGrammar
*/
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
/*
* From QServe
*/
void qserve_w4a8_per_chn_gemm(
const torch::Tensor& _in_feats,
const torch::Tensor& _kernel,
const torch::Tensor& _wscales,
const torch::Tensor& _ascales,
const torch::Tensor& _w_szs,
const torch::Tensor& _a_ssums,
torch::Tensor& _out_feats);
void qserve_w4a8_per_group_gemm(
const torch::Tensor& _in_feats,
const torch::Tensor& _kernel,
const torch::Tensor& _zeros,
const torch::Tensor& _scales_i8,
const torch::Tensor& _wscales,
const torch::Tensor& _ascales,
torch::Tensor& _out_feats);
/*
* From csrc/spatial
*/
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
/*
* From csrc/memory
*/
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);

View File

@@ -0,0 +1,122 @@
/*Adapt from:
https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <torch/library.h>
/**
* Unfortunately, the type signatures of the flash_attn ops are not compatible
* with the PyTorch library bindings. To get around that we use
* `make_pytorch_shim` which creates a lambda that exposes the API using
* PyTorch compatible types to the types, then converts them to the types
* expected by the flash_attn ops. This shims allows us to make minimal changes
* to `flash_api.cpp` making it easier to synchronize with upstream changes.
*
* The `pytorch_library_compatible_type` struct is used to map from the
* flash_attn ops types to a PyTorch library compatible one. The main issues is
* that the following types are not support by PyTorch library bindings:
* - `int`
* - `float`
* - `std::optional<T> &`
* - `std::optional<const at::Tensor> &`
* So we convert them to (respectively):
* - `int64_t`
* - `double`
* - `const std::optional<T>&`
* - `const std::optional<at::Tensor>&`
*/
template <typename T>
struct pytorch_library_compatible_type {
using type = T;
static T convert_from_type(T arg) {
return arg;
}
};
template <typename T>
using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;
template <typename T>
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
return pytorch_library_compatible_type<T>::convert_from_type(arg);
}
// Map `c10::optional<T> &` -> `const c10::optional<T>&`
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
// the optional container)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>&> {
using type = const c10::optional<T>&;
static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
return const_cast<c10::optional<T>&>(arg);
}
};
// Map `c10::optional<T>` ->
// `c10::optional<pytorch_library_compatible_type_t<T>>`
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>> {
using type = c10::optional<pytorch_library_compatible_type_t<T>>;
static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
return arg;
}
};
// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
using type = const c10::optional<at::Tensor>&;
static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
}
};
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
// Map `float` -> `double`
template <>
struct pytorch_library_compatible_type<float> {
using type = double;
static float convert_from_type(double arg) {
TORCH_CHECK(
std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
return arg;
}
};
//
// Shim Utils
//
template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
return [fun](pytorch_library_compatible_type_t<Args>... args) {
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
};
}

449
sgl-kernel/include/utils.h Normal file
View File

@@ -0,0 +1,449 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <ATen/Tensor.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#ifdef USE_ROCM
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
#define _DISPATCH_CASE_F16(c_type, ...) \
case at::ScalarType::Half: { \
using c_type = __half; \
return __VA_ARGS__(); \
}
#define _DISPATCH_CASE_BF16(c_type, ...) \
case at::ScalarType::BFloat16: { \
using c_type = __hip_bfloat16; \
return __VA_ARGS__(); \
}
#endif // USE_ROCM
#ifndef USE_ROCM
// Adapt from FlashInfer
#ifdef FLASHINFER_ENABLE_F16
#define _DISPATCH_CASE_F16(c_type, ...) \
case at::ScalarType::Half: { \
using c_type = nv_half; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_F16(c_type, ...)
#endif // FLASHINFER_ENABLE_F16
#ifdef FLASHINFER_ENABLE_BF16
#define _DISPATCH_CASE_BF16(c_type, ...) \
case at::ScalarType::BFloat16: { \
using c_type = nv_bfloat16; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_BF16(c_type, ...)
#endif // FLASHINFER_ENABLE_BF16
#ifdef FLASHINFER_ENABLE_FP8_E4M3
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
case at::ScalarType::Float8_e4m3fn: { \
using c_type = __nv_fp8_e4m3; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
#endif // FLASHINFER_ENABLE_FP8_E4M3
#ifdef FLASHINFER_ENABLE_FP8_E5M2
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
case at::ScalarType::Float8_e5m2: { \
using c_type = __nv_fp8_e5m2; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
#endif // FLASHINFER_ENABLE_FP8_E5M2
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_SWITCH(var_name, cond, ...) \
[&]() -> bool { \
switch (cond) { \
__VA_ARGS__ \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \
[&]() -> bool { \
switch (pack_u16(cond1, cond2)) { \
__VA_ARGS__ \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \
<< int(cond2) << ")"; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_CASE(case_expr, case_var, ...) \
case case_expr: { \
constexpr auto case_var = case_expr; \
return __VA_ARGS__(); \
}
#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
case pack_u16(case_expr1, case_expr2): { \
constexpr auto case_var1 = case_expr1; \
constexpr auto case_var2 = case_expr2; \
return __VA_ARGS__(); \
}
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
}
}
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
return (uint32_t(a) << 16) | uint32_t(b);
}
#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \
TORCH_CHECK( \
num_qo_heads % num_kv_heads == 0, \
"num_qo_heads(", \
num_qo_heads, \
") must be divisible by num_kv_heads(", \
num_kv_heads, \
")")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CUDA(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
inline bool is_float8_tensor(const at::Tensor& tensor) {
return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;
}
#endif // USE_ROCM
struct cuda_error : public std::runtime_error {
/**
* @brief Constructs a `cuda_error` object with the given `message`.
*
* @param message The error char array used to construct `cuda_error`
*/
cuda_error(const char* message) : std::runtime_error(message) {}
/**
* @brief Constructs a `cuda_error` object with the given `message` string.
*
* @param message The `std::string` used to construct `cuda_error`
*/
cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
};
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
std::stringstream _message; \
auto s = cudaGetErrorString(e); \
_message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
throw cuda_error(_message.str()); \
} \
} while (0)
#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
CHECK_IS_CUDA(x); \
CHECK_IS_CONTIGUOUS(x)
inline int getSMVersion() {
int device{-1};
CHECK_CUDA_SUCCESS(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
return sm_major * 10 + sm_minor;
}
inline bool isDeviceType(const std::string& device_type) {
int deviceCount;
CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&deviceCount));
int device_id = -1;
if (deviceCount >= 1) {
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
} else {
return false;
}
cudaDeviceProp prop;
CHECK_CUDA_SUCCESS(cudaGetDeviceProperties(&prop, device_id));
if (device_type == std::string(prop.name)) {
return true;
}
return false;
}
inline bool getBoolEnv(char const* name) {
char const* env = std::getenv(name);
return env && env[0] == '1' && env[1] == '\0';
}
inline bool getEnvEnablePDL() {
static std::once_flag flag;
static bool enablePDL = false;
std::call_once(flag, [&]() {
if (getSMVersion() >= 90) {
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
}
});
return enablePDL;
}
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
#else
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
case at::ScalarType::Float: { \
using c_type = float; \
return __VA_ARGS__(); \
} \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#endif
#ifdef USE_ROCM
#include "hip/hip_math_def.h"
#include "hip/hip_vec_dtypes.h"
#else
template <typename srcDtype>
__device__ __forceinline__ float castToFloat(srcDtype val) {
return static_cast<srcDtype>(val);
}
template <typename dstDtype>
__device__ __forceinline__ dstDtype castFromFloat(float val) {
return static_cast<dstDtype>(val);
}
#endif
// add FP8 support
// #ifndef USE_ROCM
// #include <c10/util/Float8_e4m3fn.h>
// using FP8_TYPE = c10::Float8_e4m3fn;
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
// #else // USE_ROCM
// #if HIP_FP8_TYPE_FNUZ
// #include <c10/util/Float8_e4m3fnuz.h>
// using FP8_TYPE = c10::Float8_e4m3fnuz;
// constexpr auto FP8_E4M3_MAX = 224.0f;
// #else
// #if HIP_FP8_TYPE_E4M3
// #include <c10/util/Float8_e4m3fn.h>
// using FP8_TYPE = c10::Float8_e4m3fn;
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
// #else
// #error "fp8 is not supported in this processor (arch < gfx942)."
// #endif // HIP_FP8_TYPE_E4M3
// #endif // HIP_FP8_TYPE_FNUZ
// #endif // USE_ROCM
#define FULL_MASK 0xffffffff
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#ifndef USE_ROCM
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
#else
int* addr_as_i = (int*)addr;
int old = *addr_as_i, assumed;
do {
assumed = old;
old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
#endif
}
__device__ __forceinline__ float warpReduceMax(float value) {
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16));
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8));
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4));
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2));
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1));
return value;
}
__device__ __forceinline__ float blockReduceMax(float value) {
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
value = warpReduceMax(value);
if (laneId == 0) warpLevelMaxs[warpId] = value;
__syncthreads();
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) value = warpReduceMax(value);
return value;
}
// Pads to a multiple of `alignment` rows.
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
int64_t rows = tensor.size(0);
int64_t cols = tensor.size(1);
int64_t pad_rows = (alignment - (rows % alignment)) % alignment; // Compute padding size
if (pad_rows == 0) {
return tensor; // Already aligned
}
torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());
torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); // Pad along rows
// Ensure column-major layout
if (is_column_major) {
return tensor_padded.t().contiguous().t();
}
return tensor_padded;
}
// Get the next power of 2 of a number
inline uint32_t next_pow2(uint32_t x) noexcept {
if (x <= 1) return 1;
return 1u << (32 - __builtin_clz(x - 1));
}