adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
87
sgl-kernel/include/hip/hip_act_and_mul.cuh
Normal file
87
sgl-kernel/include/hip/hip_act_and_mul.cuh
Normal file
@@ -0,0 +1,87 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#define kBitsToLoad 128
|
||||
#define kBytesToLoad (kBitsToLoad / 8)
|
||||
|
||||
// Adapted from
|
||||
// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)
|
||||
|
||||
namespace sgl_hip {
|
||||
namespace activation {
|
||||
|
||||
// Gated activation kernel.
//
// For the token selected by blockIdx.x, computes
//   out[token * d + i] = Activation(input[token * 2 * d + i]) * input[token * 2 * d + d + i]
// for 0 <= i < d. Threads of the block walk the row with a step of blockDim.x.
//
// The bulk of the row is processed in chunks of vec_size = kBytesToLoad / sizeof(T)
// elements through sgl_hip::vec_t; the last d % vec_size elements are handled one by one.
//
// NOTE(review): the vectorized path reads the second half at `input + offset + d`;
// this presumably requires d to be a multiple of vec_size (and suitably aligned base
// pointers) for the wide accesses to stay aligned -- confirm against the launcher.
template <typename T, T (*Activation)(const T&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  // Each token owns 2 * d consecutive input elements: [x half | y half].
  const int64_t offset = token_idx * 2 * d;
  // Number of full vec_size-wide chunks in one half.
  const int64_t num_vecs = d / static_cast<int64_t>(vec_size);

#pragma unroll 1
  for (int64_t idx = thread_idx; idx < num_vecs; idx += stride) {
    sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
    y_vec.cast_load(input + offset + d + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  // Scalar tail: only the elements the vector loop did not cover, i.e.
  // [num_vecs * vec_size, d). Starting here (instead of at d - d % (stride * vec_size))
  // avoids recomputing and rewriting elements that were already produced above.
  const int64_t tail_begin = num_vecs * vec_size;
#pragma unroll 1
  for (int64_t idx = tail_begin + thread_idx; idx < d; idx += stride) {
    const T x = input[offset + idx];
    const T y = input[offset + d + idx];
    out[token_idx * d + idx] = Activation(x) * y;
  }
}
|
||||
|
||||
// Element-wise activation kernel.
//
// For the token selected by blockIdx.x, computes
//   out[token * d + i] = Activation(input[token * d + i])   for 0 <= i < d.
// Threads of the block walk the row with a step of blockDim.x.
//
// The bulk of the row is processed in chunks of vec_size = kBytesToLoad / sizeof(T)
// elements through sgl_hip::vec_t; the last d % vec_size elements are handled one by one.
//
// NOTE(review): rows start at `token * d`, so the wide accesses presumably require d to be
// a multiple of vec_size (and aligned base pointers) -- confirm against the launcher.
template <typename T, T (*Activation)(const T&)>
__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  const int64_t offset = token_idx * d;
  // Number of full vec_size-wide chunks in one row.
  const int64_t num_vecs = d / static_cast<int64_t>(vec_size);

#pragma unroll 1
  for (int64_t idx = thread_idx; idx < num_vecs; idx += stride) {
    sgl_hip::vec_t<T, vec_size> x_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]);
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  // Scalar tail: only the elements the vector loop did not cover, i.e.
  // [num_vecs * vec_size, d). Starting here (instead of at d - d % (stride * vec_size))
  // avoids recomputing and rewriting elements that were already produced above.
  const int64_t tail_begin = num_vecs * vec_size;
#pragma unroll 1
  for (int64_t idx = tail_begin + thread_idx; idx < d; idx += stride) {
    const T x = input[offset + idx];
    out[token_idx * d + idx] = Activation(x);
  }
}
|
||||
|
||||
} // namespace activation
|
||||
} // namespace sgl_hip
|
||||
94
sgl-kernel/include/hip/hip_math_def.h
Normal file
94
sgl-kernel/include/hip/hip_math_def.h
Normal file
@@ -0,0 +1,94 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_common.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
|
||||
|
||||
namespace amdgpu {
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize);
|
||||
|
||||
template <typename srcDtype, typename destDtype>
|
||||
__forceinline__ __device__ destDtype cast(srcDtype val);
|
||||
|
||||
// specialization
|
||||
template <>
|
||||
__forceinline__ __device__ float shfl_xor_sync(unsigned mask, float var, int laneMask, int width) {
|
||||
return __shfl_xor(var, laneMask, width);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ int shfl_xor_sync(unsigned mask, int var, int laneMask, int width) {
|
||||
return __shfl_xor(var, laneMask, width);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ float cast(float val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ float cast(__half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ float cast(__hip_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ __half cast(float fval) {
|
||||
return __float2half(fval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ __hip_bfloat16 cast(float fval) {
|
||||
return __float2bfloat16(fval);
|
||||
}
|
||||
|
||||
} // namespace amdgpu
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize) {
|
||||
return amdgpu::shfl_xor_sync(mask, var, laneMask, width);
|
||||
}
|
||||
|
||||
template <typename srcDtype>
|
||||
__device__ __forceinline__ float castToFloat(srcDtype val) {
|
||||
return amdgpu::cast<srcDtype, float>(val);
|
||||
}
|
||||
|
||||
template <typename dstDtype>
|
||||
__device__ __forceinline__ dstDtype castFromFloat(float val) {
|
||||
return amdgpu::cast<float, dstDtype>(val);
|
||||
}
|
||||
|
||||
// operator overload to support flashinfer
|
||||
// Product of two __half values, computed with __hmul.
__host__ __device__ __forceinline__ __half operator*(const __half& x, const __half& y) {
  return __hmul(x, y);
}
|
||||
|
||||
#endif
|
||||
101
sgl-kernel/include/hip/hip_vec_dtypes.h
Normal file
101
sgl-kernel/include/hip/hip_vec_dtypes.h
Normal file
@@ -0,0 +1,101 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#if USE_ROCM
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_common.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
|
||||
|
||||
#define SGL_HIP_INLINE inline __attribute__((always_inline)) __device__
|
||||
|
||||
namespace sgl_hip {
|
||||
|
||||
template <typename float_t, size_t vec_size>
|
||||
struct vec_t;
|
||||
|
||||
template <typename srcDtype, typename dstDtype, size_t vec_size>
|
||||
SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src);
|
||||
|
||||
template <typename srcDtype, typename dstDtype, size_t vec_size>
|
||||
SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src);
|
||||
|
||||
template <typename float_t, size_t vec_size>
|
||||
struct vec_t {
|
||||
SGL_HIP_INLINE float_t& operator[](size_t i);
|
||||
SGL_HIP_INLINE const float_t& operator[](size_t i) const;
|
||||
SGL_HIP_INLINE float_t* ptr();
|
||||
|
||||
SGL_HIP_INLINE void load(const float_t* ptr);
|
||||
SGL_HIP_INLINE void store(float_t* ptr) const;
|
||||
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src);
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr);
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const;
|
||||
};
|
||||
|
||||
} // namespace sgl_hip
|
||||
|
||||
// **** impl *****
|
||||
|
||||
namespace sgl_hip {
|
||||
|
||||
template <typename srcDtype, typename dstDtype, size_t vec_size>
|
||||
SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src_ptr) {
|
||||
if constexpr (std::is_same<srcDtype, dstDtype>::value) {
|
||||
dst.load(src_ptr);
|
||||
} else {
|
||||
vec_t<srcDtype, vec_size> tmp;
|
||||
tmp.load(src_ptr);
|
||||
dst.cast_from(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename srcDtype, typename dstDtype, size_t vec_size>
|
||||
SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src) {
|
||||
if constexpr (std::is_same<srcDtype, dstDtype>::value) {
|
||||
src.store(dst_ptr);
|
||||
} else {
|
||||
vec_t<dstDtype, vec_size> tmp;
|
||||
tmp.cast_from(src);
|
||||
tmp.store(dst_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename float_t, size_t vec_size>
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
|
||||
template <typename float_t, size_t vec_size>
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
|
||||
} // namespace sgl_hip
|
||||
|
||||
#include "impl/hip_vec_bf16_impl.h"
|
||||
#include "impl/hip_vec_fp32_impl.h"
|
||||
#include "impl/hip_vec_half_impl.h"
|
||||
#endif
|
||||
177
sgl-kernel/include/hip/impl/hip_vec_bf16_impl.h
Normal file
177
sgl-kernel/include/hip/impl/hip_vec_bf16_impl.h
Normal file
@@ -0,0 +1,177 @@
|
||||
#pragma once
|
||||
|
||||
#if USE_ROCM
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_common.h>
|
||||
|
||||
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
|
||||
|
||||
using nv_bfloat16 = __hip_bfloat16;
|
||||
using nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// Builds a __hip_bfloat162 whose .x is `x` and whose .y is `y`.
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) {
  __hip_bfloat162 packed;
  packed.x = x;
  packed.y = y;
  return packed;
}
|
||||
|
||||
namespace sgl_hip {
|
||||
|
||||
// nv_bfloat16 x 1
|
||||
template <>
|
||||
struct vec_t<nv_bfloat16, 1> {
|
||||
nv_bfloat16 data;
|
||||
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
|
||||
return ((nv_bfloat16*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
|
||||
return ((const nv_bfloat16*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE nv_bfloat16* ptr() {
|
||||
return reinterpret_cast<nv_bfloat16*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
|
||||
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
|
||||
data = *ptr;
|
||||
}
|
||||
|
||||
SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
|
||||
*ptr = data;
|
||||
}
|
||||
|
||||
// nv_bfloat16 x 2
|
||||
template <>
|
||||
struct vec_t<nv_bfloat16, 2> {
|
||||
nv_bfloat162 data;
|
||||
|
||||
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
|
||||
return ((nv_bfloat16*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
|
||||
return ((const nv_bfloat16*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE nv_bfloat16* ptr() {
|
||||
return reinterpret_cast<nv_bfloat16*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
|
||||
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
|
||||
data = *((nv_bfloat162*)ptr);
|
||||
}
|
||||
|
||||
SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
|
||||
*((nv_bfloat162*)ptr) = data;
|
||||
}
|
||||
|
||||
template <>
|
||||
struct vec_t<nv_bfloat16, 4> {
|
||||
uint2 data;
|
||||
|
||||
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
|
||||
return ((nv_bfloat16*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
|
||||
return ((const nv_bfloat16*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE nv_bfloat16* ptr() {
|
||||
return reinterpret_cast<nv_bfloat16*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
|
||||
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
|
||||
data = *((uint2*)ptr);
|
||||
}
|
||||
|
||||
SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
|
||||
*((uint2*)ptr) = data;
|
||||
}
|
||||
|
||||
// nv_bfloat16 x 8 or more
|
||||
|
||||
template <size_t vec_size>
|
||||
struct vec_t<nv_bfloat16, vec_size> {
|
||||
uint4 data[vec_size / 8];
|
||||
|
||||
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
|
||||
return ((nv_bfloat16*)data)[i];
|
||||
}
|
||||
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
|
||||
return ((const nv_bfloat16*)data)[i];
|
||||
}
|
||||
SGL_HIP_INLINE nv_bfloat16* ptr() {
|
||||
return reinterpret_cast<nv_bfloat16*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const nv_bfloat16* ptr) {
|
||||
#pragma unoll
|
||||
for (size_t i = 0; i < vec_size / 8; ++i) {
|
||||
data[i] = ((uint4*)ptr)[i];
|
||||
}
|
||||
}
|
||||
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const {
|
||||
#pragma unoll
|
||||
for (size_t i = 0; i < vec_size / 8; ++i) {
|
||||
((uint4*)ptr)[i] = data[i];
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sgl_hip
|
||||
|
||||
#endif
|
||||
129
sgl-kernel/include/hip/impl/hip_vec_fp32_impl.h
Normal file
129
sgl-kernel/include/hip/impl/hip_vec_fp32_impl.h
Normal file
@@ -0,0 +1,129 @@
|
||||
#pragma once
|
||||
|
||||
#if USE_ROCM
|
||||
|
||||
#include <hip/hip_common.h>
|
||||
|
||||
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
|
||||
|
||||
namespace sgl_hip {
|
||||
|
||||
template <>
|
||||
struct vec_t<float, 1> {
|
||||
float data;
|
||||
|
||||
SGL_HIP_INLINE float& operator[](size_t i) {
|
||||
return ((float*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const float& operator[](size_t i) const {
|
||||
return ((const float*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE float* ptr() {
|
||||
return reinterpret_cast<float*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const float* ptr);
|
||||
SGL_HIP_INLINE void store(float* ptr) const;
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
SGL_HIP_INLINE void vec_t<float, 1>::load(const float* ptr) {
|
||||
data = *ptr;
|
||||
}
|
||||
|
||||
SGL_HIP_INLINE void vec_t<float, 1>::store(float* ptr) const {
|
||||
*ptr = data;
|
||||
}
|
||||
|
||||
// float x 2
|
||||
|
||||
template <>
|
||||
struct vec_t<float, 2> {
|
||||
float2 data;
|
||||
|
||||
SGL_HIP_INLINE float& operator[](size_t i) {
|
||||
return ((float*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const float& operator[](size_t i) const {
|
||||
return ((const float*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE float* ptr() {
|
||||
return reinterpret_cast<float*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const float* ptr);
|
||||
SGL_HIP_INLINE void store(float* ptr) const;
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
SGL_HIP_INLINE void vec_t<float, 2>::load(const float* ptr) {
|
||||
data = *((float2*)ptr);
|
||||
}
|
||||
|
||||
SGL_HIP_INLINE void vec_t<float, 2>::store(float* ptr) const {
|
||||
*((float2*)ptr) = data;
|
||||
}
|
||||
|
||||
// float x 4 or more
|
||||
template <size_t vec_size>
|
||||
struct vec_t<float, vec_size> {
|
||||
float4 data[vec_size / 4];
|
||||
|
||||
SGL_HIP_INLINE float& operator[](size_t i) {
|
||||
return ((float*)(data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const float& operator[](size_t i) const {
|
||||
return ((const float*)(data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE float* ptr() {
|
||||
return reinterpret_cast<float*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const float* ptr) {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||
data[i] = ((float4*)ptr)[i];
|
||||
}
|
||||
}
|
||||
SGL_HIP_INLINE void store(float* ptr) const {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||
((float4*)ptr)[i] = data[i];
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sgl_hip
|
||||
|
||||
#endif
|
||||
172
sgl-kernel/include/hip/impl/hip_vec_half_impl.h
Normal file
172
sgl-kernel/include/hip/impl/hip_vec_half_impl.h
Normal file
@@ -0,0 +1,172 @@
|
||||
#pragma once
|
||||
|
||||
#if USE_ROCM
|
||||
|
||||
#include <hip/hip_common.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
|
||||
|
||||
using half = __half;
|
||||
using half2 = __half2;
|
||||
|
||||
namespace sgl_hip {
|
||||
|
||||
// half x 1
|
||||
template <>
|
||||
struct vec_t<half, 1> {
|
||||
half data;
|
||||
|
||||
SGL_HIP_INLINE half& operator[](size_t i) {
|
||||
return ((half*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const half& operator[](size_t i) const {
|
||||
return ((const half*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE half* ptr() {
|
||||
return reinterpret_cast<half*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const half* ptr);
|
||||
SGL_HIP_INLINE void store(half* ptr) const;
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
SGL_HIP_INLINE void vec_t<half, 1>::load(const half* ptr) {
|
||||
data = *ptr;
|
||||
}
|
||||
|
||||
SGL_HIP_INLINE void vec_t<half, 1>::store(half* ptr) const {
|
||||
*ptr = data;
|
||||
}
|
||||
|
||||
// half x 2
|
||||
template <>
|
||||
struct vec_t<half, 2> {
|
||||
half2 data;
|
||||
|
||||
SGL_HIP_INLINE half& operator[](size_t i) {
|
||||
return ((half*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const half& operator[](size_t i) const {
|
||||
return ((const half*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE half* ptr() {
|
||||
return reinterpret_cast<half*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const half* ptr);
|
||||
SGL_HIP_INLINE void store(half* ptr) const;
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
SGL_HIP_INLINE void vec_t<half, 2>::load(const half* ptr) {
|
||||
data = *((half2*)ptr);
|
||||
}
|
||||
|
||||
SGL_HIP_INLINE void vec_t<half, 2>::store(half* ptr) const {
|
||||
*((half2*)ptr) = data;
|
||||
}
|
||||
|
||||
// half x 4
|
||||
|
||||
template <>
|
||||
struct vec_t<half, 4> {
|
||||
uint2 data;
|
||||
|
||||
SGL_HIP_INLINE half& operator[](size_t i) {
|
||||
return ((half*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE const half& operator[](size_t i) const {
|
||||
return ((const half*)(&data))[i];
|
||||
}
|
||||
SGL_HIP_INLINE half* ptr() {
|
||||
return reinterpret_cast<half*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const half* ptr);
|
||||
SGL_HIP_INLINE void store(half* ptr) const;
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
SGL_HIP_INLINE void vec_t<half, 4>::load(const half* ptr) {
|
||||
data = *((uint2*)ptr);
|
||||
}
|
||||
|
||||
SGL_HIP_INLINE void vec_t<half, 4>::store(half* ptr) const {
|
||||
*((uint2*)ptr) = data;
|
||||
}
|
||||
|
||||
// half x 8 or more
|
||||
|
||||
template <size_t vec_size>
|
||||
struct vec_t<half, vec_size> {
|
||||
uint4 data[vec_size / 8];
|
||||
|
||||
SGL_HIP_INLINE half& operator[](size_t i) {
|
||||
return ((half*)data)[i];
|
||||
}
|
||||
SGL_HIP_INLINE const half& operator[](size_t i) const {
|
||||
return ((const half*)data)[i];
|
||||
}
|
||||
SGL_HIP_INLINE half* ptr() {
|
||||
return reinterpret_cast<half*>(&data);
|
||||
}
|
||||
SGL_HIP_INLINE void load(const half* ptr) {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size / 8; ++i) {
|
||||
data[i] = ((uint4*)ptr)[i];
|
||||
}
|
||||
}
|
||||
SGL_HIP_INLINE void store(half* ptr) const {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size / 8; ++i) {
|
||||
((uint4*)ptr)[i] = data[i];
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T>
|
||||
SGL_HIP_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sgl_hip
|
||||
#endif
|
||||
20
sgl-kernel/include/pytorch_extension_utils_rocm.h
Normal file
20
sgl-kernel/include/pytorch_extension_utils_rocm.h
Normal file
@@ -0,0 +1,20 @@
|
||||
#include <torch/library.h>
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
// Fails with a TORCH_CHECK error unless the last entry of x.strides() is 1.
// The message starts with the stringified argument, hence the leading space in the literal.
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
330
sgl-kernel/include/scalar_type.hpp
Normal file
330
sgl-kernel/include/scalar_type.hpp
Normal file
@@ -0,0 +1,330 @@
|
||||
#pragma once
|
||||
|
||||
// For TORCH_CHECK
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace sglang {
|
||||
|
||||
//
|
||||
// ScalarType can represent a wide range of floating point and integer types,
|
||||
// in particular it can be used to represent sub-byte data types (something
|
||||
// that torch.dtype currently does not support).
|
||||
//
|
||||
// The type definitions on the Python side can be found in: vllm/scalar_type.py
|
||||
// these type definitions should be kept up to date with any Python API changes
|
||||
// here.
|
||||
//
|
||||
class ScalarType {
|
||||
public:
|
||||
enum NanRepr : uint8_t {
|
||||
NAN_NONE = 0, // nans are not supported
|
||||
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
||||
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
||||
|
||||
NAN_REPR_ID_MAX
|
||||
};
|
||||
|
||||
constexpr ScalarType(
|
||||
uint8_t exponent,
|
||||
uint8_t mantissa,
|
||||
bool signed_,
|
||||
int32_t bias,
|
||||
bool finite_values_only = false,
|
||||
NanRepr nan_repr = NAN_IEEE_754)
|
||||
: exponent(exponent),
|
||||
mantissa(mantissa),
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr) {};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
}
|
||||
|
||||
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits, false, bias);
|
||||
}
|
||||
|
||||
// IEEE 754 compliant floating point type
|
||||
static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) {
|
||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
|
||||
}
|
||||
|
||||
// IEEE 754 non-compliant floating point type
|
||||
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) {
|
||||
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||
TORCH_CHECK(
|
||||
nan_repr != NAN_IEEE_754,
|
||||
"use `float_IEEE754` constructor for floating point types that "
|
||||
"follow IEEE 754 conventions");
|
||||
return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr);
|
||||
}
|
||||
|
||||
uint8_t const exponent; // size of the exponent field (0 for integer types)
|
||||
uint8_t const mantissa; // size of the mantissa field (size of the integer
|
||||
// excluding the sign bit for integer types)
|
||||
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
||||
// sign bit)
|
||||
int32_t const bias; // stored values equal value + bias,
|
||||
// used for quantized type
|
||||
|
||||
// Extra Floating point info
|
||||
bool const finite_values_only; // i.e. no +/-inf if true
|
||||
NanRepr const nan_repr; // how NaNs are represented
|
||||
// (not applicable for integer types)
|
||||
|
||||
using Id = int64_t;
|
||||
|
||||
private:
|
||||
// Field size in id
|
||||
template <typename T_>
|
||||
static constexpr size_t member_id_field_width() {
|
||||
using T = std::decay_t<T_>;
|
||||
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init, typename Member, typename... Rest>
|
||||
static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) {
|
||||
auto new_val = f(val, member);
|
||||
if constexpr (sizeof...(rest) > 0) {
|
||||
return reduce_members_helper(f, new_val, rest...);
|
||||
} else {
|
||||
return new_val;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
constexpr auto reduce_members(Fn f, Init init) const {
|
||||
// Should be in constructor order for `from_id`
|
||||
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr);
|
||||
};
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
static constexpr auto reduce_member_types(Fn f, Init init) {
|
||||
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
|
||||
return dummy_type.reduce_members(f, init);
|
||||
};
|
||||
|
||||
static constexpr auto id_size_bits() {
|
||||
return reduce_member_types(
|
||||
[](int acc, auto member) -> int { return acc + member_id_field_width<decltype(member)>(); }, 0);
|
||||
}
|
||||
|
||||
public:
|
||||
// unique id for this scalar type that can be computed at compile time for
|
||||
// c++17 template specialization this is not needed once we migrate to
|
||||
// c++20 and can pass literal classes as template parameters
|
||||
constexpr Id id() const {
|
||||
static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored");
|
||||
|
||||
auto or_and_advance = [](std::pair<Id, uint32_t> result, auto member) -> std::pair<Id, uint32_t> {
|
||||
auto [id, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<decltype(member)>();
|
||||
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits};
|
||||
};
|
||||
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
|
||||
}
|
||||
|
||||
// create a ScalarType from an id, for c++17 template specialization,
|
||||
// this is not needed once we migrate to c++20 and can pass literal
|
||||
// classes as template parameters
|
||||
static constexpr ScalarType from_id(Id id) {
|
||||
auto extract_and_advance = [id](auto result, auto member) {
|
||||
using T = decltype(member);
|
||||
auto [tuple, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<T>();
|
||||
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1));
|
||||
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
|
||||
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
|
||||
};
|
||||
|
||||
auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair<std::tuple<>, int>{});
|
||||
return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args);
|
||||
}
|
||||
|
||||
constexpr int64_t size_bits() const {
|
||||
return mantissa + exponent + is_signed();
|
||||
}
|
||||
// Returns the value of the `signed_` field.
constexpr bool is_signed() const {
|
||||
return signed_;
|
||||
}
|
||||
// A type with no exponent bits is treated as an integer type.
constexpr bool is_integer() const {
|
||||
return exponent == 0;
|
||||
}
|
||||
// A type with at least one exponent bit is treated as a floating point type.
constexpr bool is_floating_point() const {
|
||||
return exponent > 0;
|
||||
}
|
||||
constexpr bool is_ieee_754() const {
|
||||
return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754;
|
||||
}
|
||||
constexpr bool has_nans() const {
|
||||
return is_floating_point() && nan_repr != NAN_NONE;
|
||||
}
|
||||
constexpr bool has_infs() const {
|
||||
return is_floating_point() && finite_values_only == false;
|
||||
}
|
||||
// True when the `bias` field is non-zero.
constexpr bool has_bias() const {
|
||||
return bias != 0;
|
||||
}
|
||||
|
||||
private:
|
||||
double _floating_point_max() const {
|
||||
TORCH_CHECK(mantissa <= 52 && exponent <= 11, "Cannot represent max/min as a double for type ", str());
|
||||
|
||||
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
|
||||
max_mantissa -= 1;
|
||||
}
|
||||
|
||||
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
|
||||
TORCH_CHECK(exponent < 11, "Cannot represent max/min as a double for type ", str());
|
||||
max_exponent += 1;
|
||||
}
|
||||
|
||||
// adjust the exponent to match that of a double
|
||||
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
|
||||
// is the exponent bits), there is some precedent for non-standard biases,
|
||||
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
|
||||
// but to avoid premature over complication we are just assuming the
|
||||
// standard exponent bias until there is a need to support non-standard
|
||||
// biases
|
||||
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
|
||||
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
|
||||
|
||||
uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double;
|
||||
|
||||
// shift the mantissa into the position for a double and
|
||||
// the exponent
|
||||
uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
|
||||
|
||||
return *reinterpret_cast<double*>(&double_raw);
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_max() const {
|
||||
if (is_floating_point()) {
|
||||
return {_floating_point_max()};
|
||||
} else {
|
||||
TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), "Cannot represent max as a int64_t");
|
||||
return {(int64_t(1) << mantissa) - 1};
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_min() const {
|
||||
if (is_floating_point()) {
|
||||
TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed");
|
||||
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
|
||||
|
||||
double max = _floating_point_max();
|
||||
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
|
||||
uint64_t min_raw = max_raw | sign_bit_double;
|
||||
return {*reinterpret_cast<double*>(&min_raw)};
|
||||
} else {
|
||||
TORCH_CHECK(!is_signed() || size_bits() <= 64, "Cannot represent min as a int64_t");
|
||||
if (is_signed()) {
|
||||
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
|
||||
// then perform an arithmetic shift right to set all the bits above
|
||||
// (size_bits() - 1) to 1
|
||||
return {INT64_MIN >> (64 - size_bits())};
|
||||
} else {
|
||||
return {int64_t(0)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Max representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> max() const {
|
||||
return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_max());
|
||||
}
|
||||
|
||||
// Min representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> min() const {
|
||||
return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_min());
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
* for floating point types (leading f) the scheme is:
|
||||
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
* flags:
|
||||
* - no-flags: means it follows IEEE 754 conventions
|
||||
* - f: means finite values only (no infinities)
|
||||
* - n: means nans are supported (non-standard encoding)
|
||||
* for integer types the scheme is:
|
||||
* `[u]int<size_bits>[b<bias>]`
|
||||
* - if bias is not present it means its zero
|
||||
*/
|
||||
if (is_floating_point()) {
|
||||
auto ret =
|
||||
"float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa);
|
||||
if (!is_ieee_754()) {
|
||||
if (finite_values_only) {
|
||||
ret += "f";
|
||||
}
|
||||
if (nan_repr != NAN_NONE) {
|
||||
ret += "n";
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
} else {
|
||||
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
|
||||
if (has_bias()) {
|
||||
ret += "b" + std::to_string(bias);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
// Two scalar types are equal when all six fields match.
constexpr bool operator==(ScalarType const& other) const {
  if (exponent != other.exponent || mantissa != other.mantissa) {
    return false;
  }
  if (signed_ != other.signed_ || bias != other.bias) {
    return false;
  }
  return finite_values_only == other.finite_values_only && nan_repr == other.nan_repr;
}
|
||||
};
|
||||
|
||||
using ScalarTypeId = ScalarType::Id;
|
||||
|
||||
// "rust style" names generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
// Integer types. The `B<n>` suffix goes with a second factory argument
// (presumably the bias -- confirm against ScalarType::uint).
static inline constexpr auto kS4 = ScalarType::int_(4);
|
||||
static inline constexpr auto kU4 = ScalarType::uint(4);
|
||||
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kS8 = ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
|
||||
|
||||
// Floating point types, named kFE<exponent bits>M<mantissa bits>[flags].
static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
// Fixed width style names, generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
// (each one is an alias of a constant defined earlier in this list)
static inline constexpr auto kInt4 = kS4;
|
||||
static inline constexpr auto kUint4 = kU4;
|
||||
static inline constexpr auto kUint4b8 = kU4B8;
|
||||
static inline constexpr auto kInt8 = kS8;
|
||||
static inline constexpr auto kUint8 = kU8;
|
||||
static inline constexpr auto kUint8b128 = kU8B128;
|
||||
|
||||
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
|
||||
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
|
||||
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
|
||||
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
|
||||
|
||||
// colloquial names
|
||||
static inline constexpr auto kHalf = kFE5M10;
|
||||
static inline constexpr auto kFloat16 = kHalf;
|
||||
static inline constexpr auto kBFloat16 = kFE8M7;
|
||||
|
||||
// Compile-time id of kFloat16, as produced by ScalarType::id().
static inline constexpr auto kFloat16Id = kFloat16.id();
|
||||
}; // namespace sglang
|
||||
86
sgl-kernel/include/sgl_flash_kernel_ops.h
Normal file
86
sgl-kernel/include/sgl_flash_kernel_ops.h
Normal file
@@ -0,0 +1,86 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <Python.h>
|
||||
#include <torch/library.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sgl_kernel_torch_shim.h"
|
||||
|
||||
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||
|
||||
#define _CONCAT(A, B) A##B
|
||||
#define CONCAT(A, B) _CONCAT(A, B)
|
||||
|
||||
#define _STRINGIFY(A) #A
|
||||
#define STRINGIFY(A) _STRINGIFY(A)
|
||||
|
||||
// Defines the CPython module init function `PyInit_<NAME>`. The module is
// created from a static PyModuleDef whose name is the stringified NAME and
// whose remaining listed fields are nullptr / 0 / nullptr.
#define REGISTER_EXTENSION(NAME) \
|
||||
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
||||
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
||||
return PyModule_Create(&module); \
|
||||
}
|
||||
|
||||
/*
|
||||
* From flash-attention
|
||||
*/
|
||||
std::vector<at::Tensor> mha_fwd(
|
||||
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
||||
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
|
||||
// h_k, d) if there is page_table.
|
||||
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
|
||||
// page_size, h_k, dv) if there is page_table.
|
||||
std::optional<const at::Tensor>&
|
||||
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
|
||||
std::optional<const at::Tensor>&
|
||||
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
|
||||
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
|
||||
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
||||
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
|
||||
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
|
||||
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
|
||||
std::optional<const at::Tensor>&
|
||||
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
|
||||
std::optional<const at::Tensor>&
|
||||
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
|
||||
std::optional<int> max_seqlen_q_,
|
||||
// TODO: check if we need max_seqlen_k
|
||||
std::optional<int> max_seqlen_k_,
|
||||
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
|
||||
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
|
||||
std::optional<const at::Tensor>& leftpad_k_, // b
|
||||
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
|
||||
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
|
||||
std::optional<const at::Tensor>& seqlens_rotary_, // b
|
||||
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
|
||||
std::optional<at::Tensor>& k_descale_, // (b, h_k)
|
||||
std::optional<at::Tensor>& v_descale_, // (b, h_k)
|
||||
float const softmax_scale,
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
float const softcap,
|
||||
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
||||
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
|
||||
int num_splits,
|
||||
std::optional<bool> pack_gqa_,
|
||||
int const sm_margin,
|
||||
std::optional<const at::Tensor>& sinks_);
|
||||
758
sgl-kernel/include/sgl_kernel_ops.h
Normal file
758
sgl-kernel/include/sgl_kernel_ops.h
Normal file
@@ -0,0 +1,758 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <Python.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "scalar_type.hpp"
|
||||
|
||||
#define _CONCAT(A, B) A##B
|
||||
#define CONCAT(A, B) _CONCAT(A, B)
|
||||
|
||||
#define _STRINGIFY(A) #A
|
||||
#define STRINGIFY(A) _STRINGIFY(A)
|
||||
|
||||
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||
|
||||
// Defines the CPython module init function `PyInit_<NAME>`. The module is
// created from a static PyModuleDef whose name is the stringified NAME and
// whose remaining listed fields are nullptr / 0 / nullptr.
#define REGISTER_EXTENSION(NAME) \
|
||||
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
||||
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
||||
return PyModule_Create(&module); \
|
||||
}
|
||||
|
||||
using fptr_t = int64_t;
|
||||
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
#ifdef USE_ROCM
|
||||
// ROCM custom allreduce
|
||||
fptr_t init_custom_ar(
|
||||
torch::Tensor& meta,
|
||||
torch::Tensor& rank_data,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets,
|
||||
int64_t rank,
|
||||
bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
|
||||
void dispose(fptr_t _fa);
|
||||
int64_t meta_size();
|
||||
void register_buffer(
|
||||
fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets);
|
||||
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
|
||||
torch::Tensor allocate_meta_buffer(int64_t size);
|
||||
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
|
||||
// quick allreduce
|
||||
fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size = std::nullopt);
|
||||
void qr_destroy(fptr_t _fa);
|
||||
torch::Tensor qr_get_handle(fptr_t _fa);
|
||||
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
|
||||
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false);
|
||||
int64_t qr_max_size();
|
||||
#else
|
||||
// custom allreduce
|
||||
fptr_t
|
||||
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
|
||||
void dispose(fptr_t _fa);
|
||||
int64_t meta_size();
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
|
||||
void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
|
||||
|
||||
// mscclpp
|
||||
torch::Tensor mscclpp_generate_unique_id();
|
||||
fptr_t mscclpp_init_context(
|
||||
const torch::Tensor& unique_id,
|
||||
const int64_t rank,
|
||||
const int64_t world_size,
|
||||
torch::Tensor& scratch,
|
||||
torch::Tensor& put_buffer,
|
||||
const int64_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib,
|
||||
const int64_t context_selection);
|
||||
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* From csrc/attention
|
||||
*/
|
||||
void lightning_attention_decode(
|
||||
const torch::Tensor& q,
|
||||
const torch::Tensor& k,
|
||||
const torch::Tensor& v,
|
||||
const torch::Tensor& past_kv,
|
||||
const torch::Tensor& slope,
|
||||
torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
void merge_state(
|
||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
|
||||
void merge_state_v2(
|
||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
|
||||
void cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
int64_t cutlass_mla_get_workspace_size(
|
||||
int64_t max_seq_len,
|
||||
int64_t num_batches,
|
||||
int64_t sm_count = 0,
|
||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
|
||||
void sgl_fused_add_rmsnorm(
|
||||
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
|
||||
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
|
||||
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
|
||||
void silu_and_mul(at::Tensor& out, at::Tensor& input);
|
||||
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input);
|
||||
void gelu_and_mul(at::Tensor& out, at::Tensor& input);
|
||||
|
||||
void apply_rope_pos_ids_cos_sin_cache(
|
||||
at::Tensor q,
|
||||
at::Tensor k,
|
||||
at::Tensor q_rope,
|
||||
at::Tensor k_rope,
|
||||
at::Tensor cos_sin_cache,
|
||||
at::Tensor pos_ids,
|
||||
bool interleave,
|
||||
bool enable_pdl,
|
||||
int64_t cuda_stream,
|
||||
const std::optional<at::Tensor>& v,
|
||||
const std::optional<at::Tensor>& k_buffer,
|
||||
const std::optional<at::Tensor>& v_buffer,
|
||||
const std::optional<at::Tensor>& kv_cache_loc);
|
||||
|
||||
void downcast_fp8(
|
||||
at::Tensor& k,
|
||||
at::Tensor& v,
|
||||
at::Tensor& k_out,
|
||||
at::Tensor& v_out,
|
||||
at::Tensor& k_scale,
|
||||
at::Tensor& v_scale,
|
||||
at::Tensor& loc,
|
||||
int64_t mult,
|
||||
int64_t offset,
|
||||
int64_t cuda_stream);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
|
||||
void cutlass_scaled_fp4_mm(
|
||||
torch::Tensor& D,
|
||||
torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
torch::Tensor int8_scaled_mm(
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
torch::Tensor fp8_scaled_mm(
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
torch::Tensor fp8_blockwise_scaled_mm(
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype);
|
||||
void scaled_fp4_quant(
|
||||
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
|
||||
void sgl_per_token_group_quant_fp8(
|
||||
at::Tensor input,
|
||||
at::Tensor output_q,
|
||||
at::Tensor output_s,
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double fp8_min,
|
||||
double fp8_max,
|
||||
bool scale_ue8m0);
|
||||
void sgl_per_token_group_quant_int8(
|
||||
at::Tensor input,
|
||||
at::Tensor output_q,
|
||||
at::Tensor output_s,
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double int8_min,
|
||||
double int8_max);
|
||||
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
|
||||
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
|
||||
void bmm_fp8(
|
||||
at::Tensor A,
|
||||
at::Tensor B,
|
||||
at::Tensor D,
|
||||
at::Tensor A_scale,
|
||||
at::Tensor B_scale,
|
||||
at::Tensor workspace_buffer,
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream);
|
||||
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
|
||||
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none,
|
||||
torch::Tensor& workspace,
|
||||
sglang::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full,
|
||||
bool use_atomic_add,
|
||||
bool use_fp32_reduce,
|
||||
bool is_zp_float);
|
||||
|
||||
torch::Tensor gptq_gemm(
|
||||
torch::Tensor a,
|
||||
torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales,
|
||||
torch::Tensor b_g_idx,
|
||||
bool use_shuffle,
|
||||
int64_t bit);
|
||||
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
|
||||
|
||||
torch::Tensor
|
||||
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int64_t num_experts,
|
||||
int64_t block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad,
|
||||
torch::Tensor cumsum_buffer,
|
||||
bool pad_sorted_token_ids);
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize);
|
||||
|
||||
std::vector<at::Tensor> moe_fused_gate(
|
||||
at::Tensor& input,
|
||||
at::Tensor& bias,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor,
|
||||
bool apply_routed_scaling_factor_on_output);
|
||||
|
||||
void fp8_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs,
|
||||
torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs,
|
||||
torch::Tensor& b_scales_ptrs,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const torch::Tensor& stride_a,
|
||||
const torch::Tensor& stride_b,
|
||||
const torch::Tensor& stride_c,
|
||||
const torch::Tensor& layout_sfa,
|
||||
const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace);
|
||||
|
||||
void prepare_moe_input(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
void ep_moe_pre_reorder(
|
||||
torch::Tensor input,
|
||||
torch::Tensor gateup_input,
|
||||
torch::Tensor src2dst,
|
||||
torch::Tensor topk_ids,
|
||||
torch::Tensor a1_scales,
|
||||
int64_t start_expert_id,
|
||||
int64_t end_expert_id,
|
||||
int64_t topk,
|
||||
bool use_per_token_if_dynamic);
|
||||
|
||||
void ep_moe_silu_and_mul(
|
||||
torch::Tensor gateup_output,
|
||||
torch::Tensor down_input,
|
||||
torch::Tensor reorder_topk_ids,
|
||||
torch::Tensor scales,
|
||||
int64_t start_expert_id,
|
||||
int64_t end_expert_id);
|
||||
|
||||
void ep_moe_post_reorder(
|
||||
torch::Tensor down_output,
|
||||
torch::Tensor output,
|
||||
torch::Tensor src2dst,
|
||||
torch::Tensor topk_ids,
|
||||
torch::Tensor topk_weights,
|
||||
int64_t start_expert_id,
|
||||
int64_t end_expert_id,
|
||||
int64_t topk);
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);
|
||||
|
||||
void apply_shuffle_mul_sum(
|
||||
const torch::Tensor& input,
|
||||
torch::Tensor& output,
|
||||
const torch::Tensor& permutation,
|
||||
const std::optional<torch::Tensor>& factors);
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale,
|
||||
const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas,
|
||||
const torch::Tensor& ab_strides,
|
||||
const torch::Tensor& c_strides,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& sf_offsets);
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& mask,
|
||||
bool use_silu_and_mul);
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
*/
|
||||
void get_cutlass_w4a8_moe_mm_data(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
void cutlass_w4a8_moe_mm(
|
||||
torch::Tensor& d_tensors,
|
||||
torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes,
|
||||
torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides,
|
||||
torch::Tensor const& d_strides,
|
||||
torch::Tensor const& s_strides,
|
||||
int64_t chunk_size,
|
||||
int64_t topk);
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none,
|
||||
torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids,
|
||||
torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded,
|
||||
torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size,
|
||||
int64_t top_k,
|
||||
bool mul_topk_weights,
|
||||
bool is_ep,
|
||||
sglang::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full,
|
||||
bool use_atomic_add,
|
||||
bool use_fp32_reduce,
|
||||
bool is_zp_float);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
void tree_speculative_sampling_target_only(
|
||||
at::Tensor predicts, // mutable
|
||||
at::Tensor accept_index, // mutable
|
||||
at::Tensor accept_token_num, // mutable
|
||||
at::Tensor candidates,
|
||||
at::Tensor retrive_index,
|
||||
at::Tensor retrive_next_token,
|
||||
at::Tensor retrive_next_sibling,
|
||||
at::Tensor uniform_samples,
|
||||
at::Tensor uniform_samples_for_final_sampling,
|
||||
at::Tensor target_probs,
|
||||
at::Tensor draft_probs,
|
||||
double threshold_single = 1,
|
||||
double threshold_acc = 1,
|
||||
bool deterministic = true,
|
||||
int64_t cuda_stream = 0);
|
||||
|
||||
void verify_tree_greedy(
|
||||
at::Tensor predicts, // mutable
|
||||
at::Tensor accept_index, // mutable
|
||||
at::Tensor accept_token_num, // mutable
|
||||
at::Tensor candidates,
|
||||
at::Tensor retrive_index,
|
||||
at::Tensor retrive_next_token,
|
||||
at::Tensor retrive_next_sibling,
|
||||
at::Tensor target_predict,
|
||||
int64_t cuda_stream = 0);
|
||||
|
||||
void build_tree_kernel_efficient(
|
||||
at::Tensor parent_list,
|
||||
at::Tensor selected_index,
|
||||
at::Tensor verified_seq_len,
|
||||
at::Tensor tree_mask,
|
||||
at::Tensor positions,
|
||||
at::Tensor retrive_index,
|
||||
at::Tensor retrive_next_token,
|
||||
at::Tensor retrive_next_sibling,
|
||||
int64_t topk,
|
||||
int64_t depth,
|
||||
int64_t draft_token_num,
|
||||
int64_t tree_mask_mode);
|
||||
|
||||
void segment_packbits(
|
||||
at::Tensor x,
|
||||
at::Tensor input_indptr,
|
||||
at::Tensor output_indptr,
|
||||
at::Tensor y,
|
||||
int64_t batch_size,
|
||||
int64_t cuda_stream = 0);
|
||||
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
void transfer_kv_per_layer(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_pf_lf(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t layer_id,
|
||||
int64_t item_size,
|
||||
int64_t src_layout_dim,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer(
|
||||
const at::Tensor src_k_layers,
|
||||
const at::Tensor dst_k_layers,
|
||||
const at::Tensor src_v_layers,
|
||||
const at::Tensor dst_v_layers,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_lf_pf(
|
||||
const at::Tensor src_k_layers,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v_layers,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t dst_layout_dim,
|
||||
int64_t num_layers,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_mla(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_mla_pf_lf(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t layer_id,
|
||||
int64_t item_size,
|
||||
int64_t src_layout_dim,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_mla(
|
||||
const at::Tensor src_layers,
|
||||
const at::Tensor dst_layers,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_mla_lf_pf(
|
||||
const at::Tensor src_layers,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t dst_layout_dim,
|
||||
int64_t num_layers,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_direct(
|
||||
const std::vector<at::Tensor>& src_layers,
|
||||
std::vector<at::Tensor> dst_layers,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
void min_p_sampling_from_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor output,
|
||||
std::optional<at::Tensor> maybe_indices,
|
||||
std::optional<at::Tensor> maybe_min_p_arr,
|
||||
double min_p_val,
|
||||
bool deterministic,
|
||||
std::optional<at::Generator> gen);
|
||||
|
||||
void top_k_renorm_probs(
|
||||
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
||||
|
||||
void top_p_renorm_probs(
|
||||
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
|
||||
|
||||
void top_k_top_p_sampling_from_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor output,
|
||||
std::optional<at::Tensor> maybe_indices,
|
||||
std::optional<at::Tensor> maybe_top_k_arr,
|
||||
double top_k_val,
|
||||
std::optional<at::Tensor> maybe_top_p_arr,
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
std::optional<at::Generator> gen);
|
||||
|
||||
void top_p_sampling_from_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor output,
|
||||
std::optional<at::Tensor> maybe_indices,
|
||||
std::optional<at::Tensor> maybe_top_p_arr,
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
std::optional<at::Generator> gen);
|
||||
|
||||
void top_k_mask_logits(
|
||||
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
||||
|
||||
namespace flash {
|
||||
/*
|
||||
* From fa2 sparse
|
||||
*/
|
||||
std::vector<at::Tensor> mha_fwd_sparse(
|
||||
at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor& block_count,
|
||||
const at::Tensor& block_offset,
|
||||
const at::Tensor& column_count,
|
||||
const at::Tensor& column_index,
|
||||
const std::optional<at::Tensor>& out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
const std::optional<at::Tensor>& alibi_slopes_, // num_heads or batch_size x num_heads
|
||||
const double p_dropout,
|
||||
const double softmax_scale,
|
||||
bool is_causal,
|
||||
const double softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_);
|
||||
|
||||
std::vector<at::Tensor> mha_varlen_fwd_sparse(
|
||||
at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
|
||||
const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
|
||||
const at::Tensor& block_count,
|
||||
const at::Tensor& block_offset,
|
||||
const at::Tensor& column_count,
|
||||
const at::Tensor& column_index,
|
||||
const c10::optional<at::Tensor>& out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor& cu_seqlens_q, // b+1
|
||||
const at::Tensor& cu_seqlens_k, // b+1
|
||||
const c10::optional<at::Tensor>&
|
||||
seqused_k, // b. If given, only this many elements of each batch element's keys are used.
|
||||
const c10::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
|
||||
int64_t max_seqlen_q,
|
||||
const int64_t max_seqlen_k,
|
||||
const double p_dropout,
|
||||
const double softmax_scale,
|
||||
const bool zero_tensors,
|
||||
bool is_causal,
|
||||
const double softcap,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_);
|
||||
} // namespace flash
|
||||
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int64_t context_size,
|
||||
int64_t block_size_M,
|
||||
int64_t block_size_N,
|
||||
bool causal);
|
||||
|
||||
void convert_vertical_slash_indexes_mergehead(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
torch::Tensor vertical_indices_count, // [N_HEADS, ]
|
||||
torch::Tensor slash_indices_count,
|
||||
int64_t context_size,
|
||||
int64_t block_size_M,
|
||||
int64_t block_size_N,
|
||||
bool causal);
|
||||
|
||||
/*
|
||||
* From XGrammar
|
||||
*/
|
||||
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
|
||||
|
||||
/*
|
||||
* From QServe
|
||||
*/
|
||||
void qserve_w4a8_per_chn_gemm(
|
||||
const torch::Tensor& _in_feats,
|
||||
const torch::Tensor& _kernel,
|
||||
const torch::Tensor& _wscales,
|
||||
const torch::Tensor& _ascales,
|
||||
const torch::Tensor& _w_szs,
|
||||
const torch::Tensor& _a_ssums,
|
||||
torch::Tensor& _out_feats);
|
||||
|
||||
void qserve_w4a8_per_group_gemm(
|
||||
const torch::Tensor& _in_feats,
|
||||
const torch::Tensor& _kernel,
|
||||
const torch::Tensor& _zeros,
|
||||
const torch::Tensor& _scales_i8,
|
||||
const torch::Tensor& _wscales,
|
||||
const torch::Tensor& _ascales,
|
||||
torch::Tensor& _out_feats);
|
||||
|
||||
/*
|
||||
* From csrc/spatial
|
||||
*/
|
||||
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
|
||||
|
||||
/*
|
||||
* From csrc/memory
|
||||
*/
|
||||
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
|
||||
122
sgl-kernel/include/sgl_kernel_torch_shim.h
Normal file
122
sgl-kernel/include/sgl_kernel_torch_shim.h
Normal file
@@ -0,0 +1,122 @@
|
||||
/* Adapted from:
|
||||
https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
|
||||
Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
/**
|
||||
* Unfortunately, the type signatures of the flash_attn ops are not compatible
|
||||
* with the PyTorch library bindings. To get around that we use
|
||||
* `make_pytorch_shim` which creates a lambda that exposes the API using
|
||||
* PyTorch compatible types, then converts them to the types
|
||||
* expected by the flash_attn ops. This shim allows us to make minimal changes
|
||||
* to `flash_api.cpp` making it easier to synchronize with upstream changes.
|
||||
*
|
||||
* The `pytorch_library_compatible_type` struct is used to map from the
|
||||
* flash_attn ops types to a PyTorch library compatible one. The main issue is
|
||||
* that the following types are not supported by PyTorch library bindings:
|
||||
* - `int`
|
||||
* - `float`
|
||||
* - `std::optional<T> &`
|
||||
* - `std::optional<const at::Tensor> &`
|
||||
* So we convert them to (respectively):
|
||||
* - `int64_t`
|
||||
* - `double`
|
||||
* - `const std::optional<T>&`
|
||||
* - `const std::optional<at::Tensor>&`
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
struct pytorch_library_compatible_type {
|
||||
using type = T;
|
||||
static T convert_from_type(T arg) {
|
||||
return arg;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;
|
||||
|
||||
template <typename T>
|
||||
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
|
||||
return pytorch_library_compatible_type<T>::convert_from_type(arg);
|
||||
}
|
||||
|
||||
// Map `c10::optional<T> &` -> `const c10::optional<T>&`
|
||||
// (NOTE: this is a bit unsafe, but none of the ops in flash_attn mutate
|
||||
// the optional container)
|
||||
template <typename T>
|
||||
struct pytorch_library_compatible_type<c10::optional<T>&> {
|
||||
using type = const c10::optional<T>&;
|
||||
static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
|
||||
return const_cast<c10::optional<T>&>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
// Map `c10::optional<T>` ->
|
||||
// `c10::optional<pytorch_library_compatible_type_t<T>>`
|
||||
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
|
||||
template <typename T>
|
||||
struct pytorch_library_compatible_type<c10::optional<T>> {
|
||||
using type = c10::optional<pytorch_library_compatible_type_t<T>>;
|
||||
static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
|
||||
return arg;
|
||||
}
|
||||
};
|
||||
|
||||
// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
|
||||
template <>
|
||||
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
|
||||
using type = const c10::optional<at::Tensor>&;
|
||||
static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
|
||||
return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
|
||||
}
|
||||
};
|
||||
|
||||
// Map `int` -> `int64_t`
|
||||
template <>
|
||||
struct pytorch_library_compatible_type<int> {
|
||||
using type = int64_t;
|
||||
static int convert_from_type(int64_t arg) {
|
||||
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
|
||||
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
|
||||
return arg;
|
||||
}
|
||||
};
|
||||
|
||||
// Map `float` -> `double`
|
||||
template <>
|
||||
struct pytorch_library_compatible_type<float> {
|
||||
using type = double;
|
||||
static float convert_from_type(double arg) {
|
||||
TORCH_CHECK(
|
||||
std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
|
||||
return arg;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Shim Utils
|
||||
//
|
||||
|
||||
template <typename Ret, typename... Args>
|
||||
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
|
||||
return [fun](pytorch_library_compatible_type_t<Args>... args) {
|
||||
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
|
||||
};
|
||||
}
|
||||
449
sgl-kernel/include/utils.h
Normal file
449
sgl-kernel/include/utils.h
Normal file
@@ -0,0 +1,449 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
|
||||
#define _DISPATCH_CASE_F16(c_type, ...) \
|
||||
case at::ScalarType::Half: { \
|
||||
using c_type = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define _DISPATCH_CASE_BF16(c_type, ...) \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using c_type = __hip_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Adapted from FlashInfer
|
||||
#ifdef FLASHINFER_ENABLE_F16
|
||||
#define _DISPATCH_CASE_F16(c_type, ...) \
|
||||
case at::ScalarType::Half: { \
|
||||
using c_type = nv_half; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_F16(c_type, ...)
|
||||
#endif // FLASHINFER_ENABLE_F16
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_BF16
|
||||
#define _DISPATCH_CASE_BF16(c_type, ...) \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using c_type = nv_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_BF16(c_type, ...)
|
||||
#endif // FLASHINFER_ENABLE_BF16
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_FP8_E4M3
|
||||
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
|
||||
case at::ScalarType::Float8_e4m3fn: { \
|
||||
using c_type = __nv_fp8_e4m3; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
|
||||
#endif // FLASHINFER_ENABLE_FP8_E4M3
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_FP8_E5M2
|
||||
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
|
||||
case at::ScalarType::Float8_e5m2: { \
|
||||
using c_type = __nv_fp8_e5m2; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
|
||||
#endif // FLASHINFER_ENABLE_FP8_E5M2
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define _DISPATCH_SWITCH(var_name, cond, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (cond) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pack_u16(cond1, cond2)) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \
|
||||
<< int(cond2) << ")"; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define _DISPATCH_CASE(case_expr, case_var, ...) \
|
||||
case case_expr: { \
|
||||
constexpr auto case_var = case_expr; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
|
||||
case pack_u16(case_expr1, case_expr2): { \
|
||||
constexpr auto case_var1 = case_expr1; \
|
||||
constexpr auto case_var2 = case_expr2; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
[&]() -> bool { \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// Verifies that `a` and `b` have the same number of dimensions and the same
// extent along each of them. `a_name` / `b_name` are only used to build the
// TORCH_CHECK failure message.
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
  const auto a_rank = a.dim();
  const auto b_rank = b.dim();
  TORCH_CHECK(a_rank == b_rank, a_name, ".dim() != ", b_name, ".dim(). ", a_rank, " vs ", b_rank);

  for (int axis = 0; axis < a_rank; ++axis) {
    const bool same_extent = a.size(axis) == b.size(axis);
    TORCH_CHECK(same_extent, a_name, ".size(", axis, ") != ", b_name, ".size(", axis, ")");
  }
}
|
||||
|
||||
// Packs two 16-bit values into one 32-bit key: `a` occupies the upper half,
// `b` the lower half.
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  const uint32_t upper_half = static_cast<uint32_t>(a) << 16;
  const uint32_t lower_half = static_cast<uint32_t>(b);
  return upper_half | lower_half;
}
|
||||
|
||||
#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \
|
||||
TORCH_CHECK( \
|
||||
num_qo_heads % num_kv_heads == 0, \
|
||||
"num_qo_heads(", \
|
||||
num_qo_heads, \
|
||||
") must be divisible by num_kv_heads(", \
|
||||
num_kv_heads, \
|
||||
")")
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
// Fails unless the last entry of `x.strides()` is 1. The message literal
// starts with a space so the stringified tensor name is not glued to "must".
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// True when the tensor's scalar type is Float8_e4m3fn or Float8_e5m2.
inline bool is_float8_tensor(const at::Tensor& tensor) {
  switch (tensor.scalar_type()) {
    case at::ScalarType::Float8_e4m3fn:
    case at::ScalarType::Float8_e5m2:
      return true;
    default:
      return false;
  }
}
|
||||
#endif // USE_ROCM
|
||||
|
||||
/**
 * @brief Exception thrown by `CHECK_CUDA_SUCCESS` when a runtime call fails.
 *
 * A thin `std::runtime_error` subtype; `what()` returns the supplied message.
 */
struct cuda_error : public std::runtime_error {
  /**
   * @brief Builds the exception from a NUL-terminated message.
   *
   * @param message Text reported through `what()`
   */
  cuda_error(const char* message) : std::runtime_error(message) {}

  /**
   * @brief Builds the exception from a `std::string` message.
   *
   * @param message Text reported through `what()`
   */
  cuda_error(std::string const& message) : std::runtime_error(message.c_str()) {}
};
|
||||
|
||||
// Evaluates `cmd` once; if the result is not `cudaSuccess`, throws `cuda_error`
// carrying the runtime's error string followed by "\n<file>:<line>".
// `cmd` is parenthesized and the temporaries use distinctive names so that an
// expression mentioning a caller variable called `e` or `s` is not captured by
// the macro's own locals.
#define CHECK_CUDA_SUCCESS(cmd)                                                          \
  do {                                                                                   \
    const cudaError_t _sgl_cuda_status = (cmd);                                          \
    if (_sgl_cuda_status != cudaSuccess) {                                               \
      std::stringstream _sgl_cuda_message;                                               \
      const char* _sgl_cuda_reason = cudaGetErrorString(_sgl_cuda_status);               \
      _sgl_cuda_message << std::string(_sgl_cuda_reason) + "\n" << __FILE__ << ':' << __LINE__; \
      throw cuda_error(_sgl_cuda_message.str());                                         \
    }                                                                                    \
  } while (0)
|
||||
|
||||
#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_CUDA_INPUT(x) \
|
||||
CHECK_IS_CUDA(x); \
|
||||
CHECK_IS_CONTIGUOUS(x)
|
||||
|
||||
// Queries the compute capability of the current device and returns it encoded
// as major * 10 + minor. Any failing runtime call throws via CHECK_CUDA_SUCCESS.
inline int getSMVersion() {
  int current_device = -1;
  CHECK_CUDA_SUCCESS(cudaGetDevice(&current_device));

  int cc_major = 0;
  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, current_device));

  int cc_minor = 0;
  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, current_device));

  return 10 * cc_major + cc_minor;
}
|
||||
|
||||
inline bool isDeviceType(const std::string& device_type) {
|
||||
int deviceCount;
|
||||
CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&deviceCount));
|
||||
|
||||
int device_id = -1;
|
||||
if (deviceCount >= 1) {
|
||||
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaDeviceProp prop;
|
||||
CHECK_CUDA_SUCCESS(cudaGetDeviceProperties(&prop, device_id));
|
||||
if (device_type == std::string(prop.name)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reads environment variable `name` and returns true only when its value is
// exactly the one-character string "1"; unset or any other value is false.
inline bool getBoolEnv(char const* name) {
  char const* raw_value = std::getenv(name);
  if (raw_value == nullptr) {
    return false;
  }
  if (raw_value[0] != '1') {
    return false;
  }
  return raw_value[1] == '\0';
}
|
||||
|
||||
inline bool getEnvEnablePDL() {
|
||||
static std::once_flag flag;
|
||||
static bool enablePDL = false;
|
||||
std::call_once(flag, [&]() {
|
||||
if (getSMVersion() >= 90) {
|
||||
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
|
||||
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
|
||||
}
|
||||
});
|
||||
return enablePDL;
|
||||
}
|
||||
|
||||
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
|
||||
#ifndef USE_ROCM
|
||||
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
|
||||
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
|
||||
#else
|
||||
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
|
||||
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
|
||||
#endif
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using c_type = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||
|
||||
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||
|
||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
#define WARP_SIZE 64
|
||||
#else
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include "hip/hip_math_def.h"
|
||||
#include "hip/hip_vec_dtypes.h"
|
||||
|
||||
#else
|
||||
|
||||
template <typename srcDtype>
|
||||
__device__ __forceinline__ float castToFloat(srcDtype val) {
|
||||
return static_cast<srcDtype>(val);
|
||||
}
|
||||
|
||||
template <typename dstDtype>
|
||||
__device__ __forceinline__ dstDtype castFromFloat(float val) {
|
||||
return static_cast<dstDtype>(val);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// add FP8 support
|
||||
// #ifndef USE_ROCM
|
||||
// #include <c10/util/Float8_e4m3fn.h>
|
||||
// using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
// #else // USE_ROCM
|
||||
// #if HIP_FP8_TYPE_FNUZ
|
||||
// #include <c10/util/Float8_e4m3fnuz.h>
|
||||
// using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
// constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||
// #else
|
||||
// #if HIP_FP8_TYPE_E4M3
|
||||
// #include <c10/util/Float8_e4m3fn.h>
|
||||
// using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
// #else
|
||||
// #error "fp8 is not supported in this processor (arch < gfx942)."
|
||||
// #endif // HIP_FP8_TYPE_E4M3
|
||||
// #endif // HIP_FP8_TYPE_FNUZ
|
||||
// #endif // USE_ROCM
|
||||
|
||||
#define FULL_MASK 0xffffffff
|
||||
|
||||
// Atomically raises `*addr` to at least `value` and returns the value that was
// stored at `addr` before the update.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#ifndef USE_ROCM
  // Non-negative inputs go through a signed-integer atomicMax on the raw bits;
  // everything else goes through an unsigned atomicMin on the raw bits.
  if (value >= 0) {
    return __int_as_float(atomicMax((int*)addr, __float_as_int(value)));
  }
  return __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
#else
  // Compare-and-swap loop: retry until the word we based fmaxf on is still the
  // word in memory at the time of the swap.
  int* const word = (int*)addr;
  int observed = *word;
  int expected;
  do {
    expected = observed;
    const float candidate = fmaxf(value, __int_as_float(expected));
    observed = atomicCAS(word, expected, __float_as_int(candidate));
  } while (expected != observed);
  return __int_as_float(observed);
#endif
}
|
||||
|
||||
// Butterfly (XOR-shuffle) max-reduction across one warp; every participating
// lane ends up with the maximum of all lanes' inputs.
//
// The offsets start at WARP_SIZE / 2, so the reduction also spans 64-lane
// warps (WARP_SIZE == 64 in the ROCm configuration of this file). A fixed
// 16..1 sequence would only combine lanes within each 32-lane half there.
// The shuffle goes through SGLANG_SHFL_XOR_SYNC, which maps to
// `__shfl_xor_sync` on CUDA and to `__shfl_xor` (mask unused) on ROCm.
// With WARP_SIZE == 32 this performs exactly the offsets 16, 8, 4, 2, 1.
__device__ __forceinline__ float warpReduceMax(float value) {
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
    value = fmaxf(value, SGLANG_SHFL_XOR_SYNC(FULL_MASK, value, offset));
  }
  return value;
}
|
||||
|
||||
// Two-stage max-reduction over a thread block.
//   1. Each warp reduces its own values with warpReduceMax, and lane 0 of the
//      warp stores the result in shared memory.
//   2. After a block-wide barrier, the first (blockDim.x / WARP_SIZE) threads
//      load one per-warp result each (all other threads use 0), and warp 0
//      reduces those values.
// Only warp 0 runs the second stage, so the block-wide result should be read
// from a thread in warp 0.
// NOTE(review): the 0 fill value and the truncating blockDim.x / WARP_SIZE look
// like they assume non-negative inputs and a block size that is a multiple of
// WARP_SIZE; confirm against callers.
__device__ __forceinline__ float blockReduceMax(float value) {
  static __shared__ float perWarpMax[WARP_SIZE];
  const int lane = threadIdx.x % WARP_SIZE;
  const int warp = threadIdx.x / WARP_SIZE;

  // Stage 1: reduce within each warp and publish one value per warp.
  value = warpReduceMax(value);
  if (lane == 0) {
    perWarpMax[warp] = value;
  }
  __syncthreads();

  // Stage 2: gather the per-warp values and reduce them in warp 0.
  if (threadIdx.x < blockDim.x / WARP_SIZE) {
    value = perWarpMax[lane];
  } else {
    value = 0;
  }
  if (warp == 0) {
    value = warpReduceMax(value);
  }

  return value;
}
|
||||
|
||||
// Pads to a multiple of `alignment` rows.
|
||||
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
|
||||
int64_t rows = tensor.size(0);
|
||||
int64_t cols = tensor.size(1);
|
||||
int64_t pad_rows = (alignment - (rows % alignment)) % alignment; // Compute padding size
|
||||
|
||||
if (pad_rows == 0) {
|
||||
return tensor; // Already aligned
|
||||
}
|
||||
|
||||
torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());
|
||||
torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); // Pad along rows
|
||||
|
||||
// Ensure column-major layout
|
||||
if (is_column_major) {
|
||||
return tensor_padded.t().contiguous().t();
|
||||
}
|
||||
return tensor_padded;
|
||||
}
|
||||
|
||||
// Rounds `x` up to the nearest power of two; 0 and 1 both map to 1, and an
// exact power of two maps to itself.
inline uint32_t next_pow2(uint32_t x) noexcept {
  if (x < 2) {
    return 1;
  }
  // Number of bits needed to represent x - 1 is the exponent of the result.
  const int leading_zeros = __builtin_clz(x - 1);
  const int exponent = 32 - leading_zeros;
  return 1u << exponent;
}
|
||||
Reference in New Issue
Block a user