sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,170 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#ifndef USE_ROCM
#include <flashinfer/activation.cuh>
#include "utils.h"
#else
#include "hip/hip_act_and_mul.cuh"
#endif
// Adapted from flashinfer activation
// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44
namespace detail {

// Widens a value of element type T to float.
// ROCm builds route through castToFloat; other builds use a plain static_cast.
template <typename T>
__device__ __forceinline__ float to_f32(const T& x) {
#if USE_ROCM
  const float widened = castToFloat(x);
#else
  const float widened = static_cast<float>(x);
#endif
  return widened;
}

// Narrows a float back to element type T.
// ROCm builds route through castFromFloat<T>; other builds use a plain static_cast.
template <typename T>
__device__ __forceinline__ T from_f32(float f32) {
#if USE_ROCM
  const T narrowed = castFromFloat<T>(f32);
#else
  const T narrowed = static_cast<T>(f32);
#endif
  return narrowed;
}

}  // namespace detail
// SiLU: x / (1 + exp(-x)), evaluated in fp32 and converted back to T.
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  const float v = detail::to_f32(x);
  const float denom = 1.0f + expf(-v);
  return detail::from_f32<T>(v / denom);
}
// Erf-based GELU: x * 0.5 * (1 + erf(x * (1 / sqrt(2)))), converted back to T.
template <typename T>
__device__ __forceinline__ T gelu(const T& x) {
  constexpr float kInvSqrt2 = M_SQRT1_2;
  const float v = detail::to_f32(x);
  // `auto` keeps whatever type the erf overload yields, so the arithmetic is unchanged.
  const auto cdf = 0.5f * (1.0f + erf(v * kInvSqrt2));
  return detail::from_f32<T>(v * cdf);
}
// Quick GELU: gelu_quick(x) = x * sigmoid(1.702 * x), written as x / (1 + exp(-1.702 * x)).
template <typename T>
__device__ __forceinline__ T gelu_quick_act(const T& x) {
  const float v = detail::to_f32(x);
  const float denom = 1.0f + expf(-v * 1.702f);
  return detail::from_f32<T>(v / denom);
}
// Tanh-approximated GELU: x * 0.5 * (1 + tanh(kBeta * (x + kAlpha * x^3))).
template <typename T>
__device__ __forceinline__ T gelu_tanh(const T& x) {
  constexpr float kCubicCoeff = 0.044715f;
  constexpr float kOuterScale = 0.7978845608028654f;
  const float v = detail::to_f32(x);
  const float inner = kOuterScale * (v + kCubicCoeff * v * v * v);
  const float cdf = 0.5f * (1.0f + tanhf(inner));
  return detail::from_f32<T>(v * cdf);
}
// Host launcher for act_and_mul_kernel<c_type, silu>.
//
// Launch shape: one block per row of `input` (numel / size(-1) rows) and
// min(d / vec_size, 1024) threads, where d = size(-1) / 2 and
// vec_size = 16 / sizeof(c_type). `out` and `input` are passed as raw pointers.
// NOTE(review): d < vec_size would give a zero-thread block; callers presumably
// guarantee d is a multiple of vec_size — confirm.
void silu_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to compute. This also avoids dividing by a zero-sized last dimension
  // and launching a zero-sized grid, which is an invalid configuration.
  if (input.numel() == 0) {
    return;
  }
  // Switch to the input's device BEFORE asking for the current stream, so the
  // stream belongs to the device the kernel is launched on.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
// Host launcher for act_and_mul_kernel<c_type, gelu_tanh>.
//
// Launch shape: one block per row of `input` (numel / size(-1) rows) and
// min(d / vec_size, 1024) threads, where d = size(-1) / 2 and
// vec_size = 16 / sizeof(c_type).
// NOTE(review): d < vec_size would give a zero-thread block — confirm callers never do this.
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to compute. This also avoids dividing by a zero-sized last dimension
  // and launching a zero-sized grid, which is an invalid configuration.
  if (input.numel() == 0) {
    return;
  }
  // Install the device guard first so the stream queried below belongs to the
  // input's device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
// Host launcher for act_and_mul_kernel<c_type, gelu>.
//
// Launch shape: one block per row of `input` (numel / size(-1) rows) and
// min(d / vec_size, 1024) threads, where d = size(-1) / 2 and
// vec_size = 16 / sizeof(c_type).
// NOTE(review): d < vec_size would give a zero-thread block — confirm callers never do this.
void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to compute. This also avoids dividing by a zero-sized last dimension
  // and launching a zero-sized grid, which is an invalid configuration.
  if (input.numel() == 0) {
    return;
  }
  // Install the device guard first so the stream queried below belongs to the
  // input's device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
#if USE_ROCM
// Host launcher for act_only_kernel<c_type, gelu_quick_act> (ROCm builds only).
//
// Unlike the *_and_mul launchers, d is the full last-dimension size. One block
// is launched per row of `input`, with min(d / vec_size, 1024) threads.
// NOTE(review): d < vec_size would give a zero-thread block — confirm callers never do this.
void gelu_quick(at::Tensor& out, const at::Tensor& input) {
  // Nothing to compute. This also avoids dividing by a zero-sized last dimension
  // and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  // Install the device guard first so the stream queried below belongs to the
  // input's device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
    return true;
  });
}
#endif

View File

@@ -0,0 +1,173 @@
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <torch/all.h>
#ifndef USE_ROCM
#include <flashinfer/activation.cuh>
#include "utils_hip.h"
#else
#include "hip/hip_act_and_mul_hip.cuh"
#endif
// Adapted from flashinfer activation
// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44
namespace detail {

// Widens a value of element type T to float.
// ROCm builds route through castToFloat; other builds use a plain static_cast.
template <typename T>
__device__ __forceinline__ float to_f32(const T& x) {
#if USE_ROCM
  const float widened = castToFloat(x);
#else
  const float widened = static_cast<float>(x);
#endif
  return widened;
}

// Narrows a float back to element type T.
// ROCm builds route through castFromFloat<T>; other builds use a plain static_cast.
template <typename T>
__device__ __forceinline__ T from_f32(float f32) {
#if USE_ROCM
  const T narrowed = castFromFloat<T>(f32);
#else
  const T narrowed = static_cast<T>(f32);
#endif
  return narrowed;
}

}  // namespace detail
// SiLU: x / (1 + exp(-x)), evaluated in fp32 and converted back to T.
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  const float v = detail::to_f32(x);
  const float denom = 1.0f + expf(-v);
  return detail::from_f32<T>(v / denom);
}
// Erf-based GELU: x * 0.5 * (1 + erf(x * (1 / sqrt(2)))), converted back to T.
template <typename T>
__device__ __forceinline__ T gelu(const T& x) {
  constexpr float kInvSqrt2 = M_SQRT1_2;
  const float v = detail::to_f32(x);
  // `auto` keeps whatever type the erf overload yields, so the arithmetic is unchanged.
  const auto cdf = 0.5f * (1.0f + erf(v * kInvSqrt2));
  return detail::from_f32<T>(v * cdf);
}
// Quick GELU: gelu_quick(x) = x * sigmoid(1.702 * x), written as x / (1 + exp(-1.702 * x)).
template <typename T>
__device__ __forceinline__ T gelu_quick_act(const T& x) {
  const float v = detail::to_f32(x);
  const float denom = 1.0f + expf(-v * 1.702f);
  return detail::from_f32<T>(v / denom);
}
// Tanh-approximated GELU: x * 0.5 * (1 + tanh(kBeta * (x + kAlpha * x^3))).
template <typename T>
__device__ __forceinline__ T gelu_tanh(const T& x) {
  constexpr float kCubicCoeff = 0.044715f;
  constexpr float kOuterScale = 0.7978845608028654f;
  const float v = detail::to_f32(x);
  const float inner = kOuterScale * (v + kCubicCoeff * v * v * v);
  const float cdf = 0.5f * (1.0f + tanhf(inner));
  return detail::from_f32<T>(v * cdf);
}
// Host launcher for act_and_mul_kernel<c_type, silu> (hipified variant).
//
// Launch shape: one block per row of `input` (numel / size(-1) rows) and
// min(d / vec_size, 1024) threads, where d = size(-1) / 2 and
// vec_size = 16 / sizeof(c_type).
// NOTE(review): d < vec_size would give a zero-thread block — confirm callers never do this.
void silu_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to compute. This also avoids dividing by a zero-sized last dimension
  // and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  // Install the device guard first so the stream queried below belongs to the
  // input's device.
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
  const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(::min(d / vec_size, 1024U));
#if USE_ROCM
    hipLaunchKernelGGL(( sgl_hip::activation::act_and_mul_kernel<c_type, silu>)
        , dim3(grid), dim3(block), 0, stream, static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel<c_type, silu>)
        , dim3(grid), dim3(block), 0, stream, static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
// Host launcher for act_and_mul_kernel<c_type, gelu_tanh> (hipified variant).
//
// Launch shape: one block per row of `input` (numel / size(-1) rows) and
// min(d / vec_size, 1024) threads, where d = size(-1) / 2 and
// vec_size = 16 / sizeof(c_type).
// NOTE(review): d < vec_size would give a zero-thread block — confirm callers never do this.
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to compute. This also avoids dividing by a zero-sized last dimension
  // and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  // Install the device guard first so the stream queried below belongs to the
  // input's device.
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
  const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(::min(d / vec_size, 1024U));
#if USE_ROCM
    hipLaunchKernelGGL(( sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>)
        , dim3(grid), dim3(block), 0, stream, static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>)
        , dim3(grid), dim3(block), 0, stream, static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
// Host launcher for act_and_mul_kernel<c_type, gelu> (hipified variant).
//
// Launch shape: one block per row of `input` (numel / size(-1) rows) and
// min(d / vec_size, 1024) threads, where d = size(-1) / 2 and
// vec_size = 16 / sizeof(c_type).
// NOTE(review): d < vec_size would give a zero-thread block — confirm callers never do this.
void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to compute. This also avoids dividing by a zero-sized last dimension
  // and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  // Install the device guard first so the stream queried below belongs to the
  // input's device.
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
  const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(::min(d / vec_size, 1024U));
#if USE_ROCM
    hipLaunchKernelGGL(( sgl_hip::activation::act_and_mul_kernel<c_type, gelu>)
        , dim3(grid), dim3(block), 0, stream, static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel<c_type, gelu>)
        , dim3(grid), dim3(block), 0, stream, static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
#if USE_ROCM
// Host launcher for act_only_kernel<c_type, gelu_quick_act> (hipified variant, ROCm only).
//
// Unlike the *_and_mul launchers, d is the full last-dimension size. One block
// is launched per row of `input`, with min(d / vec_size, 1024) threads.
// NOTE(review): d < vec_size would give a zero-thread block — confirm callers never do this.
void gelu_quick(at::Tensor& out, const at::Tensor& input) {
  // Nothing to compute. This also avoids dividing by a zero-sized last dimension
  // and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  // Install the device guard first so the stream queried below belongs to the
  // input's device.
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
  const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(::min(d / vec_size, 1024U));
    hipLaunchKernelGGL(( sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>)
        , dim3(grid), dim3(block), 0, stream, static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
    return true;
  });
}
#endif

View File

@@ -0,0 +1,171 @@
#include "pytorch_extension_utils.h"
// Trait that converts a value of type T to FP8 storage bits.
// The primary template is a placeholder that always yields 0; the
// specializations below perform the real conversions.
template <typename T>
struct ConvertToFP8 {
  static __device__ __nv_fp8_storage_t convert_to_fp8(T value) {
    return 0;
  }
};

// bfloat16 -> FP8 (E4M3 interpretation, saturate-to-finite).
template <>
struct ConvertToFP8<__nv_bfloat16> {
  static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) {
    const __nv_fp8_storage_t bits = __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
    return bits;
  }
};

// half -> FP8 (E4M3 interpretation, saturate-to-finite).
template <>
struct ConvertToFP8<half> {
  static __device__ __nv_fp8_storage_t convert_to_fp8(half value) {
    const __nv_fp8_storage_t bits = __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
    return bits;
  }
};
// Trait that converts a float to element type T.
// The primary template is a placeholder that always yields 0; the
// specializations below perform the real conversions.
template <typename T>
struct ConvertFromFloat {
  static __device__ T convert_from_float(float value) {
    return 0;
  }
};

// float -> bfloat16 via __float2bfloat16.
template <>
struct ConvertFromFloat<__nv_bfloat16> {
  static __device__ __nv_bfloat16 convert_from_float(float value) {
    const __nv_bfloat16 narrowed = __float2bfloat16(value);
    return narrowed;
  }
};

// float -> half via __float2half.
template <>
struct ConvertFromFloat<half> {
  static __device__ half convert_from_float(float value) {
    const half narrowed = __float2half(value);
    return narrowed;
  }
};
// Scales, clamps and converts K/V rows to FP8, scattering them into the output buffers.
//
// Layout as used by the code:
// - blockIdx.x is the input token index; blocks with blockIdx.x >= input_sl do nothing.
// - Threads of a block stride over the head * dim elements of that token.
// - Input element i of token t is read at t * head * dim + i.
// - It is written at (loc[t] * mult + offset) * head * dim + i.
// - Each value is multiplied by 1 / scale[0] (computed in type T), clamped to
//   [min_fp8, max_fp8], then converted with ConvertToFP8<T>.
//
// All flat offsets are computed in 64-bit so large caches, or large `loc`
// values, do not overflow 32-bit int.
template <typename T>
__global__ void fused_downcast_kernel(
    const T* cache_k,
    const T* cache_v,
    const float* k_scale,
    const float* v_scale,
    __nv_fp8_storage_t* output_k,
    __nv_fp8_storage_t* output_v,
    const int input_sl,
    const int head,
    const int dim,
    const T max_fp8,
    const T min_fp8,
    const int64_t mult,
    const int64_t offset,
    const int64_t* loc) {
  // TODO: change name
  const int token_idx = blockIdx.x;
  if (token_idx >= input_sl) {
    return;
  }

  T k_scale_val = ConvertFromFloat<T>::convert_from_float(k_scale[0]);
  T v_scale_val = ConvertFromFloat<T>::convert_from_float(v_scale[0]);
  T k_scale_inv = static_cast<T>(1.f) / k_scale_val;
  T v_scale_inv = static_cast<T>(1.f) / v_scale_val;
  auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };

  // Keep the full 64-bit destination slot; do not truncate loc[] to int.
  const int64_t out_seq_idx = loc[token_idx];
  const int64_t row_elems = static_cast<int64_t>(head) * dim;
  const int64_t in_base = static_cast<int64_t>(token_idx) * row_elems;
  const int64_t out_base = (out_seq_idx * mult + offset) * row_elems;

#pragma unroll
  for (int64_t i = threadIdx.x; i < row_elems; i += blockDim.x) {
    const int64_t in_idx = in_base + i;
    const int64_t out_idx = out_base + i;

    T k_val = cache_k[in_idx] * k_scale_inv;
    k_val = clamp(k_val);
    output_k[out_idx] = ConvertToFP8<T>::convert_to_fp8(k_val);

    T v_val = cache_v[in_idx] * v_scale_inv;
    v_val = clamp(v_val);
    output_v[out_idx] = ConvertToFP8<T>::convert_to_fp8(v_val);
  }
}
// Validates the tensors and launches fused_downcast_kernel<T> on `stream`.
//
// k is read as (input_sl, head, dim) via size(0..2); v is read with the same
// flat indexing. k_scale / v_scale are read as float, and loc as int64 with at
// least input_sl entries. One block is launched per input token, with
// clamp(dim / 8, 1, 1024) threads striding over that token's head * dim elements.
//
// Raises (TORCH_CHECK) on invalid inputs or when the launch reports an error.
template <typename T>
void downcast_fp8_impl(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset,
    cudaStream_t stream) {
  CHECK_INPUT(k);
  CHECK_INPUT(v);
  CHECK_INPUT(k_out);
  CHECK_INPUT(v_out);
  CHECK_INPUT(k_scale);
  CHECK_INPUT(v_scale);
  CHECK_INPUT(loc);

  // The kernel reinterprets these buffers through raw pointers, so a dtype
  // mismatch would silently read garbage.
  TORCH_CHECK(k_scale.scalar_type() == at::ScalarType::Float, "k_scale must be a float32 tensor");
  TORCH_CHECK(v_scale.scalar_type() == at::ScalarType::Float, "v_scale must be a float32 tensor");
  TORCH_CHECK(k_scale.numel() >= 1 && v_scale.numel() >= 1, "k_scale and v_scale must not be empty");
  TORCH_CHECK(loc.scalar_type() == at::ScalarType::Long, "loc must be an int64 tensor");

  int64_t input_sl = k.size(0);
  int64_t head = k.size(1);
  int64_t dim = k.size(2);
  TORCH_CHECK(loc.numel() >= input_sl, "loc must have at least one entry per input token");

  // Nothing to convert; a zero-sized grid would be an invalid launch configuration.
  if (input_sl == 0) {
    return;
  }

  // The kernel indexes tokens by blockIdx.x only, so one block per token is enough.
  dim3 grid(input_sl);
  int vec_size = 8;
  // Keep at least one thread per block so dim < vec_size still launches validly.
  dim3 block(std::max(std::min(int(dim) / vec_size, 1024), 1));

  const T max_fp8 = static_cast<T>(448.0f);
  const T min_fp8 = static_cast<T>(-448.0f);

  fused_downcast_kernel<T><<<grid, block, 0, stream>>>(
      static_cast<const T*>(k.data_ptr()),
      static_cast<const T*>(v.data_ptr()),
      static_cast<const float*>(k_scale.data_ptr()),
      static_cast<const float*>(v_scale.data_ptr()),
      static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()),
      static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()),
      input_sl,
      head,
      dim,
      max_fp8,
      min_fp8,
      mult,
      offset,
      static_cast<const int64_t*>(loc.data_ptr()));

  cudaError_t status = cudaGetLastError();
  TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status)));
}
// Public entry point: dispatches on k's dtype (bfloat16 or float16) to
// downcast_fp8_impl. `cuda_stream` is a cudaStream_t handle passed as an integer.
void downcast_fp8(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset,
    int64_t cuda_stream) {
  CHECK_INPUT(k);
  CHECK_INPUT(v);
  CHECK_INPUT(k_out);
  CHECK_INPUT(v_out);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

  const at::ScalarType input_dtype = k.scalar_type();
  if (input_dtype == at::ScalarType::BFloat16) {
    downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
  } else if (input_dtype == at::ScalarType::Half) {
    downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
  } else {
    TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16.");
  }
}

View File

@@ -0,0 +1,117 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <cuda_runtime.h>
#include "pytorch_extension_utils.h"
// Compile-time shape constants for concat_mla_k; the host wrapper checks its
// tensor arguments against these values.
// Number of heads (size(1)) required of k and k_nope.
constexpr int NUM_LOCAL_HEADS = 128;
// Last-dimension size required of k_nope.
constexpr int QK_NOPE_HEAD_DIM = 128;
// Last-dimension size required of k_rope.
constexpr int QK_ROPE_HEAD_DIM = 64;
// Last-dimension size required of k: the nope part followed by the rope part.
constexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
// Number of heads handled by one warp in concat_mla_k_kernel.
constexpr int HEAD_CHUNK_SIZE = 16;
// Number of warps needed to cover all heads of one token.
constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE;
// Returns the calling thread's lane index within its warp, read from the PTX
// %laneid special register.
__forceinline__ __device__ int get_lane_id() {
int lane_id;
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
// Integer division of a by b, rounded up (computed as (a + b - 1) / b).
int ceil_div(int a, int b) {
  const int bumped = a + b - 1;
  return bumped / b;
}
// Writes k[token, head, :] = concat(k_nope[token, head, :], k_rope[token, :]).
//
// Work split, as coded below:
// - Each warp handles one (token, chunk of HEAD_CHUNK_SIZE heads) pair;
//   flat_warp_id = global thread id / 32, so a 1-D launch is assumed.
// - Per head, each lane moves one int2 (4 bf16) of the nope part and one int
//   (2 bf16) of the rope part. The rope row is loaded once and reused for every
//   head in the chunk.
// - Warps whose token_id >= num_tokens exit immediately.
//
// Preconditions: strides are in elements, and the int/int2 accesses require
// suitably aligned base pointers and strides.
// NOTE(review): only base-pointer alignment is checked by the host wrapper in
// this file — confirm stride alignment for non-contiguous views.
//
// Element offsets are computed in 64-bit: token_id * stride overflows 32-bit
// int once token_id * k_stride_0 exceeds 2^31 - 1.
__global__ void concat_mla_k_kernel(
    nv_bfloat16* k,
    nv_bfloat16* k_nope,
    nv_bfloat16* k_rope,
    const int num_tokens,
    const int k_stride_0,
    const int k_stride_1,
    const int k_nope_stride_0,
    const int k_nope_stride_1,
    const int k_rope_stride_0) {
  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;
  const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
  const int lane_id = get_lane_id();

  if (token_id >= num_tokens) {
    return;
  }

  using KNopeBufType = int2;
  static_assert(sizeof(KNopeBufType) == QK_NOPE_HEAD_DIM * sizeof(k[0]) / 32);
  KNopeBufType k_nope_buf[HEAD_CHUNK_SIZE];

  using KRopeBufType = int;
  static_assert(sizeof(KRopeBufType) == QK_ROPE_HEAD_DIM * sizeof(k[0]) / 32);
  KRopeBufType k_rope_buf;

  // Per-token base offsets, widened before multiplying to avoid int overflow.
  const int64_t k_token_off = static_cast<int64_t>(token_id) * k_stride_0;
  const int64_t k_nope_token_off = static_cast<int64_t>(token_id) * k_nope_stride_0;
  const int64_t k_rope_token_off = static_cast<int64_t>(token_id) * k_rope_stride_0;

  {
    const int* base_addr = reinterpret_cast<int*>(k_rope + k_rope_token_off);
    k_rope_buf = *(base_addr + lane_id);
  }

  // Issue all loads for the chunk first, then all stores.
#pragma unroll
  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
    const int64_t head_off = static_cast<int64_t>(head_id) * k_nope_stride_1;
    const int2* base_addr = reinterpret_cast<int2*>(k_nope + k_nope_token_off + head_off);
    k_nope_buf[i] = *(base_addr + lane_id);
  }

#pragma unroll
  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
    nv_bfloat16* k_head = k + k_token_off + static_cast<int64_t>(head_id) * k_stride_1;

    {
      int2* base_addr = reinterpret_cast<int2*>(k_head);
      *(base_addr + lane_id) = k_nope_buf[i];
    }
    {
      int* base_addr = reinterpret_cast<int*>(k_head + QK_NOPE_HEAD_DIM);
      *(base_addr + lane_id) = k_rope_buf;
    }
  }
}
// Asserts that `t` is a 3-D CUDA tensor of shape (shape0, shape1, shape2) and
// dtype `dtype`, and that its data pointer is 16-byte aligned.
// Any violation fails the corresponding TORCH_CHECK.
inline void check_tensor(const at::Tensor& t, int64_t shape0, int64_t shape1, int64_t shape2, c10::ScalarType dtype) {
TORCH_CHECK_EQ(t.dim(), 3);
TORCH_CHECK_EQ(t.size(0), shape0);
TORCH_CHECK_EQ(t.size(1), shape1);
TORCH_CHECK_EQ(t.size(2), shape2);
TORCH_CHECK_EQ(t.dtype(), dtype);
TORCH_CHECK(t.device().is_cuda());
// Only the base pointer is checked here; strides are not.
TORCH_CHECK_EQ(((int64_t)t.data_ptr()) % 16, 0); // alignment
}
// Host wrapper: validates the three bf16 tensors and launches
// concat_mla_k_kernel on the current CUDA stream.
//
// Required shapes, enforced below:
//   k      : (num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM)
//   k_nope : (num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM)
//   k_rope : (num_tokens, 1, QK_ROPE_HEAD_DIM)
// All three must have unit stride in the last dimension. The launch uses 32
// warps per block and one warp per (token, head chunk).
void concat_mla_k(at::Tensor k, at::Tensor k_nope, at::Tensor k_rope) {
  const int num_tokens = k.size(0);

  check_tensor(k, num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, at::kBFloat16);
  check_tensor(k_nope, num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, at::kBFloat16);
  check_tensor(k_rope, num_tokens, 1, QK_ROPE_HEAD_DIM, at::kBFloat16);
  TORCH_CHECK_EQ(k.stride(2), 1);
  TORCH_CHECK_EQ(k_nope.stride(2), 1);
  TORCH_CHECK_EQ(k_rope.stride(2), 1);

  // Nothing to copy; a zero-sized grid would be rejected as an invalid launch
  // configuration and surface as an error below.
  if (num_tokens == 0) {
    return;
  }

  const auto stream = at::cuda::getCurrentCUDAStream().stream();

  constexpr int num_warps_per_block = 32;
  const int grid_size = ceil_div(num_tokens * NUM_HEAD_CHUNKS, num_warps_per_block);
  const int block_size = num_warps_per_block * 32;

  concat_mla_k_kernel<<<grid_size, block_size, 0, stream>>>(
      reinterpret_cast<nv_bfloat16*>(k.data_ptr()),
      reinterpret_cast<nv_bfloat16*>(k_nope.data_ptr()),
      reinterpret_cast<nv_bfloat16*>(k_rope.data_ptr()),
      num_tokens,
      k.stride(0),
      k.stride(1),
      k_nope.stride(0),
      k_nope.stride(1),
      k_rope.stride(0));
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
}

View File

@@ -0,0 +1,58 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <vector>
// Fixed-size wrapper around N ints. copy_to_gpu_no_ce_kernel takes it by value,
// so the values travel to the device as a kernel argument.
template <int N>
struct InputArray {
int values[N];
};
// Copies the N ints carried in the by-value kernel argument into `output`.
// Each thread with global index < N writes one element.
template <int N>
__global__ void copy_to_gpu_no_ce_kernel(const InputArray<N> input_array, int* output) {
  const int slot = threadIdx.x + blockIdx.x * blockDim.x;
  if (slot >= N) {
    return;
  }
  output[slot] = input_array.values[slot];
}
// Copies N int32 values from a contiguous 1-D CPU tensor into a contiguous 1-D
// CUDA tensor by packing them into a kernel argument and launching a single
// N-thread block on the current stream of the output's device.
//
// Raises (TORCH_CHECK) if shapes, dtypes, contiguity or devices do not match
// the requirements above, or if the launch fails.
template <int N>
void copy_to_gpu_no_ce_impl(const at::Tensor& input, at::Tensor& output) {
  TORCH_CHECK(input.dim() == 1, "input must be 1-D");
  TORCH_CHECK(static_cast<int>(input.numel()) == N, "input numel must equal template N");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.dtype() == torch::kInt32, "input dtype must be int32");

  TORCH_CHECK(output.dim() == 1, "output dim");
  TORCH_CHECK(static_cast<int>(output.numel()) == N, "output size");
  TORCH_CHECK(output.is_contiguous(), "output contiguous");
  TORCH_CHECK(output.dtype() == torch::kInt32, "output dtype");

  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
  TORCH_CHECK(output.device().is_cuda(), "output must be a CUDA tensor");

  // Stage the host values into the by-value kernel argument.
  InputArray<N> input_array;
  const int* input_ptr = input.data_ptr<int>();
  for (int i = 0; i < N; ++i)
    input_array.values[i] = input_ptr[i];

  // Make the output's device current before querying the stream, so the kernel
  // runs on the device that owns `output` even when another device is active.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // may use multi thread blocks if performance bottleneck
  dim3 grid(1);
  dim3 block(N);  // equals input.numel(), as checked above
  copy_to_gpu_no_ce_kernel<<<grid, block, 0, stream>>>(input_array, output.data_ptr<int>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Dispatches to the compile-time-sized implementation matching input.numel().
// Only sizes 72 and 64 are supported; any other size fails a TORCH_CHECK.
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output) {
  // Can use macro if there are more N needed
  switch (static_cast<int>(input.numel())) {
    case 72:
      copy_to_gpu_no_ce_impl<72>(input, output);
      break;
    case 64:
      copy_to_gpu_no_ce_impl<64>(input, output);
      break;
    default:
      TORCH_CHECK(false, "unexpected N");
  }
}

View File

@@ -0,0 +1,59 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/cuda/CUDAContext.h>
#include <flashinfer/norm.cuh>
#include "utils.h"
using namespace flashinfer;
// Validates the arguments and calls flashinfer's norm::FusedAddRMSNorm on the
// current CUDA stream.
//
//   input    : (batch_size, hidden_size)
//   residual : (batch_size, hidden_size), same device and dtype as input
//   weight   : (hidden_size), same device and dtype as input
//   eps, enable_pdl : forwarded unchanged to FusedAddRMSNorm
//
// Raises (TORCH_CHECK / CHECK_*) on shape, device or dtype mismatch, or when
// FusedAddRMSNorm returns an error.
void sgl_fused_add_rmsnorm(
    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
  CHECK_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  auto device = input.device();
  CHECK_EQ(residual.device(), device);
  CHECK_EQ(weight.device(), device);
  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
  CHECK_DIM(1, weight);    // weight: (hidden_size)
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));
  // All three buffers are reinterpreted as c_type* below, so their dtypes must agree.
  TORCH_CHECK(residual.scalar_type() == input.scalar_type(), "residual dtype must match input dtype");
  TORCH_CHECK(weight.scalar_type() == input.scalar_type(), "weight dtype must match input dtype");

  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);

  // Nothing to normalize; avoid handing an empty batch to the launcher.
  if (batch_size == 0) {
    return;
  }

  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::FusedAddRMSNorm(
        static_cast<c_type*>(input.data_ptr()),
        static_cast<c_type*>(residual.data_ptr()),
        static_cast<c_type*>(weight.data_ptr()),
        batch_size,
        hidden_size,
        input.stride(0),
        residual.stride(0),
        eps,
        enable_pdl,
        torch_current_stream);
    TORCH_CHECK(
        status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}

View File

@@ -0,0 +1,467 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SGL_POS_ENC_CUH_
#define SGL_POS_ENC_CUH_
#include <flashinfer/pos_enc.cuh> // upstream
namespace flashinfer {
namespace kv_buffer_saver {

// Reads the cache slot for token `idx` into `kv_cache_offset` and loads this
// thread's vec_size-wide slice of V (head `kv_head_idx`) into `v_vec`.
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void prepare(
    vec_t<float, vec_size>& v_vec,
    IdType& kv_cache_offset,
    DType* v,
    IdType* kv_cache_loc,
    uint32_t idx,
    uint32_t tx,
    uint32_t kv_head_idx,
    size_t v_stride_n,
    size_t v_stride_h) {
  kv_cache_offset = kv_cache_loc[idx];
  DType* const v_row = v + get_elem_offset_impl(idx, kv_head_idx, 0, v_stride_n, v_stride_h);
  v_vec.cast_load(v_row + tx * vec_size);
}

// Stores this thread's K and V slices into k_buffer / v_buffer at the row
// selected by `kv_cache_offset` and head `kv_head_idx`.
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void save(
    IdType& kv_cache_offset,
    vec_t<float, vec_size>& k_vec,
    vec_t<float, vec_size>& v_vec,
    DType* k_buffer,
    DType* v_buffer,
    uint32_t idx,
    uint32_t tx,
    uint32_t kv_head_idx,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h) {
  DType* const k_row =
      k_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, k_buffer_stride_n, k_buffer_stride_h);
  k_vec.cast_store(k_row + tx * vec_size);

  DType* const v_row =
      v_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, v_buffer_stride_n, v_buffer_stride_h);
  v_vec.cast_store(v_row + tx * vec_size);
}

}  // namespace kv_buffer_saver
// Applies rotary position embedding to Q and K from a precomputed cos/sin cache,
// handling one head per blockIdx.y.
//
// Indexing, as read from the code below:
// - Token index idx = blockIdx.x * blockDim.y + threadIdx.y; threads with
//   idx >= nnz do no work.
// - blockIdx.y < num_qo_heads selects a query head; otherwise
//   blockIdx.y - num_qo_heads is used as the KV head index.
//   NOTE(review): num_kv_heads is not checked in this kernel; presumably the
//   launcher sets gridDim.y = num_qo_heads + num_kv_heads — confirm.
// - threadIdx.x selects a vec_size-wide slice of the head dimension.
// - cos/sin are loaded from row pos_ids[idx] of cos_sin_cache (row length
//   rotary_dim, sin starting at rotary_dim / 2) only when tx * vec_size < rotary_dim.
// The rotated Q/K are written to q_rope/k_rope. When save_kv_cache is true, the
// rotated K and the V slice loaded from `v` are also written to
// k_buffer/v_buffer at row kv_cache_loc[idx].
// On SM90+ the body is bracketed by griddepcontrol.wait and
// griddepcontrol.launch_dependents.
template <
bool save_kv_cache,
bool interleave,
uint32_t head_dim,
uint32_t vec_size,
uint32_t bdx,
typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel(
DType* q,
DType* k,
DType* v,
DType* q_rope,
DType* k_rope,
DType* k_buffer,
DType* v_buffer,
float* __restrict__ cos_sin_cache,
IdType* __restrict__ pos_ids,
uint32_t nnz,
uint32_t num_qo_heads,
uint32_t num_kv_heads,
uint32_t rotary_dim,
size_t q_stride_n,
size_t q_stride_h,
size_t k_stride_n,
size_t k_stride_h,
size_t v_stride_n,
size_t v_stride_h,
size_t q_rope_stride_n,
size_t q_rope_stride_h,
size_t k_rope_stride_n,
size_t k_rope_stride_h,
size_t k_buffer_stride_n,
size_t k_buffer_stride_h,
size_t v_buffer_stride_n,
size_t v_buffer_stride_h,
IdType* __restrict__ kv_cache_loc) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
uint32_t by = blockIdx.y;
const uint32_t bdy = blockDim.y;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
const uint32_t idx = bx * bdy + ty;
const IdType pos = pos_ids[idx];
const int half_rotary_dim = rotary_dim / 2;
// 1. if interleave:
// - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
// - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
// 2. if not interleave
// - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
// - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
// Threads whose slice lies at or beyond rotary_dim skip the load, leaving cos/sin unset.
if (tx * vec_size < rotary_dim) {
int sin_offset = rotary_dim / 2;
int vec_idx;
if constexpr (interleave) {
vec_idx = (tx * vec_size) / 2; // Force integer division
} else {
vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim
}
cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
}
// blockIdx.y picks the head: query heads first, then KV heads.
if (by < num_qo_heads) {
uint32_t qo_head_idx = by;
DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
vec_t<float, vec_size> q_vec;
if constexpr (interleave) {
q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
} else {
q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
}
q_vec.cast_store(q_rope_ptr + tx * vec_size);
} else {
uint32_t kv_head_idx = by - num_qo_heads;
DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
vec_t<float, vec_size> v_vec;
IdType kv_cache_offset;
// V and the cache slot are read before K is rotated and stored.
if constexpr (save_kv_cache) {
kv_buffer_saver::prepare<DType, IdType, vec_size>(
v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
}
vec_t<float, vec_size> k_vec;
if constexpr (interleave) {
k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
} else {
k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
}
k_vec.cast_store(k_rope_ptr + tx * vec_size);
if constexpr (save_kv_cache) {
kv_buffer_saver::save<DType, IdType, vec_size>(
kv_cache_offset,
k_vec,
v_vec,
k_buffer,
v_buffer,
idx,
tx,
kv_head_idx,
k_buffer_stride_n,
k_buffer_stride_h,
v_buffer_stride_n,
v_buffer_stride_h);
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Applies rotary position embedding to Q and K from a precomputed cos/sin cache,
// with each (threadIdx.y, blockIdx.x) token slot looping over all heads.
//
// Indexing, as read from the code below:
// - Token index idx = blockIdx.x * blockDim.y + threadIdx.y; threads with
//   idx >= nnz do no work.
// - threadIdx.x selects a vec_size-wide slice of the head dimension.
// - cos/sin are loaded once per token from row pos_ids[idx] of cos_sin_cache
//   (row length rotary_dim, sin starting at rotary_dim / 2) only when
//   tx * vec_size < rotary_dim, and are reused for every head.
// The rotated Q/K are written to q_rope/k_rope. When save_kv_cache is true, the
// rotated K and the V slice loaded from `v` are also written to
// k_buffer/v_buffer at row kv_cache_loc[idx].
// On SM90+ the body is bracketed by griddepcontrol.wait and
// griddepcontrol.launch_dependents.
template <
bool save_kv_cache,
bool interleave,
uint32_t head_dim,
uint32_t vec_size,
uint32_t bdx,
typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
DType* q,
DType* k,
DType* v,
DType* q_rope,
DType* k_rope,
DType* k_buffer,
DType* v_buffer,
float* __restrict__ cos_sin_cache,
IdType* __restrict__ pos_ids,
uint32_t nnz,
uint32_t num_qo_heads,
uint32_t num_kv_heads,
uint32_t rotary_dim,
size_t q_stride_n,
size_t q_stride_h,
size_t k_stride_n,
size_t k_stride_h,
size_t v_stride_n,
size_t v_stride_h,
size_t q_rope_stride_n,
size_t q_rope_stride_h,
size_t k_rope_stride_n,
size_t k_rope_stride_h,
size_t k_buffer_stride_n,
size_t k_buffer_stride_h,
size_t v_buffer_stride_n,
size_t v_buffer_stride_h,
IdType* __restrict__ kv_cache_loc) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
const uint32_t bdy = blockDim.y;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
const uint32_t idx = bx * bdy + ty;
const IdType pos = pos_ids[idx];
const int half_rotary_dim = rotary_dim / 2;
// 1. if interleave:
// - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
// - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
// 2. if not interleave
// - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
// - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
// Threads whose slice lies at or beyond rotary_dim skip the load, leaving cos/sin unset.
if (tx * vec_size < rotary_dim) {
int sin_offset = rotary_dim / 2;
int vec_idx;
if constexpr (interleave) {
vec_idx = (tx * vec_size) / 2; // Force integer division
} else {
vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim
}
cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
}
// not to unroll the loop, because num head might be large and might lead to worse performance
#pragma unroll 1
for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
vec_t<float, vec_size> q_vec;
if constexpr (interleave) {
q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
} else {
q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
}
q_vec.cast_store(q_rope_ptr + tx * vec_size);
}
#pragma unroll 1
for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
vec_t<float, vec_size> v_vec;
IdType kv_cache_offset;
// V and the cache slot are read before K is rotated and stored.
if constexpr (save_kv_cache) {
kv_buffer_saver::prepare<DType, IdType, vec_size>(
v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
}
vec_t<float, vec_size> k_vec;
if constexpr (interleave) {
k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
} else {
k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
}
k_vec.cast_store(k_rope_ptr + tx * vec_size);
if constexpr (save_kv_cache) {
kv_buffer_saver::save<DType, IdType, vec_size>(
kv_cache_offset,
k_vec,
v_vec,
k_buffer,
v_buffer,
idx,
tx,
kv_head_idx,
k_buffer_stride_n,
k_buffer_stride_h,
v_buffer_stride_n,
v_buffer_stride_h);
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Turns the runtime boolean `save_kv_cache` into a compile-time constant.
//
// Expands to an if/else whose two branches each declare a `const bool` named by the second macro
// argument (`true` in the first branch, `false` in the second) and then paste the remaining
// arguments (`__VA_ARGS__`) verbatim. The pasted code can therefore use that name as a template
// argument, and is instantiated once per value.
//
// NOTE(review): the expansion is a bare if/else (not wrapped in do { } while (0)), so using it as
// the unbraced body of an outer `if` that has its own `else` would change which `if` that `else`
// binds to.
#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
if (save_kv_cache) { \
const bool SAVE_KV_CACHE = true; \
__VA_ARGS__ \
} else { \
const bool SAVE_KV_CACHE = false; \
__VA_ARGS__ \
}
// Host-side launcher for the "enhanced" rotary-embedding kernels.
//
// Picks one of two kernels and launches it on `stream` through cudaLaunchKernelEx:
//   * BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel: 1-D grid of ceil(nnz / bdy) blocks;
//     each (bdx, bdy) block handles bdy tokens and loops over all heads inside the kernel.
//   * BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel: 2-D grid whose y
//     dimension is num_qo_heads + num_kv_heads. Used when the 1-D grid would have fewer blocks
//     than the device can keep resident (occupancy * number of SMs).
//
// Parameters
//   q, k, v              Input tensors (device pointers). `v` is forwarded to the kernel untouched.
//   q_rope, k_rope       Output tensors for the rotated q / k.
//   k_buffer, v_buffer   KV-cache destinations, forwarded to the kernel.
//   cos_sin_cache        float cache, `rotary_dim` floats per position: the kernel reads cos from
//                        the first half of a row and sin from the second half.
//   pos_ids              One position id per token; the kernel reads pos_ids[idx] for idx < nnz.
//   nnz                  Number of tokens.
//   num_qo_heads,
//   num_kv_heads         Head counts for q and k/v.
//   rotary_dim           Row width of cos_sin_cache.
//   head_dim             Per-head width; dispatched to the compile-time HEAD_DIM.
//   *_stride_n/_stride_h Element strides along the token / head axes of the matching tensor.
//   kv_cache_loc         Per-token cache slot indices, forwarded to the kernel.
//   interleave           Selects the interleaved vs. non-interleaved rotation variant.
//   save_kv_cache        Compile-time dispatched flag enabling the KV-cache save inside the kernel.
//   enable_pdl           Value written to the programmaticStreamSerializationAllowed launch
//                        attribute.
//   stream               CUDA stream to launch on.
//
// Returns cudaSuccess when the launch was issued (or when there was nothing to do).
// NOTE(review): FLASHINFER_CUDA_CALL is defined elsewhere; it presumably returns the failing
// cudaError_t from this function — confirm against its definition.
template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* cos_sin_cache,
    IdType* pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,
    uint32_t rotary_dim,
    uint32_t head_dim,
    size_t q_stride_n,
    size_t q_stride_h,
    size_t k_stride_n,
    size_t k_stride_h,
    size_t v_stride_n,
    size_t v_stride_h,
    size_t q_rope_stride_n,
    size_t q_rope_stride_h,
    size_t k_rope_stride_n,
    size_t k_rope_stride_h,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h,
    IdType* kv_cache_loc,
    bool interleave,
    bool save_kv_cache,
    bool enable_pdl,
    cudaStream_t stream = nullptr) {
  // Empty batch: ceil(nnz / bdy) would be 0 and a launch with a zero-sized grid is rejected by
  // the CUDA runtime as an invalid configuration. There is nothing to rotate or store, so treat
  // it as a successful no-op instead of surfacing a spurious error.
  if (nnz == 0) {
    return cudaSuccess;
  }

  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

// Launches `kernel_name` with the `nblks` / `nthrs` visible at the expansion site. The single
// launch attribute carries `enable_pdl` as programmaticStreamSerializationAllowed.
#define LAUNCH_KERNEL_RAW(kernel_name)                                    \
  do {                                                                    \
    cudaLaunchConfig_t config = {};                                       \
    config.gridDim = nblks;                                               \
    config.blockDim = nthrs;                                              \
    config.dynamicSmemBytes = 0;                                          \
    config.stream = stream;                                               \
    cudaLaunchAttribute attrs[1] = {};                                    \
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;     \
    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;     \
    config.numAttrs = 1;                                                  \
    config.attrs = attrs;                                                 \
                                                                          \
    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(                              \
        &config,                                                          \
        kernel_name,                                                      \
        q,                                                                \
        k,                                                                \
        v,                                                                \
        q_rope,                                                           \
        k_rope,                                                           \
        k_buffer,                                                         \
        v_buffer,                                                         \
        cos_sin_cache,                                                    \
        pos_ids,                                                          \
        nnz,                                                              \
        num_qo_heads,                                                     \
        num_kv_heads,                                                     \
        rotary_dim,                                                       \
        q_stride_n,                                                       \
        q_stride_h,                                                       \
        k_stride_n,                                                       \
        k_stride_h,                                                       \
        v_stride_n,                                                       \
        v_stride_h,                                                       \
        q_rope_stride_n,                                                  \
        q_rope_stride_h,                                                  \
        k_rope_stride_n,                                                  \
        k_rope_stride_h,                                                  \
        k_buffer_stride_n,                                                \
        k_buffer_stride_h,                                                \
        v_buffer_stride_n,                                                \
        v_buffer_stride_h,                                                \
        kv_cache_loc));                                                   \
  } while (0)

  DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
    DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
      DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
        // operate on 16 Bytes at a time
        constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
        // how many threads needed per head_dim
        constexpr uint32_t bdx = HEAD_DIM / vec_size;
        // how many threads needed per block
        uint32_t num_threads = std::max(128U, bdx);
        // how many tokens can we process in a block
        uint32_t bdy = num_threads / bdx;
        // how many blocks needed to process all tokens
        uint32_t nblks_x = (nnz + bdy - 1) / bdy;

        auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
            SAVE_KV_CACHE,
            INTERLEAVE,
            HEAD_DIM,
            vec_size,
            bdx,
            DType,
            IdType>;

        // Number of blocks of kernel_0 the whole device can hold resident at once.
        int num_blocks_per_sm_0 = 0;
        FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
        uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;

        if (nblks_x >= num_ctas_0) {
          // Enough token blocks to fill the device: one block row, heads looped in-kernel.
          dim3 nblks(nblks_x);
          dim3 nthrs(bdx, bdy);
          LAUNCH_KERNEL_RAW(kernel_0);
        } else {
          // Too few token blocks: spread the q and kv heads over gridDim.y as well.
          dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
          dim3 nthrs(bdx, bdy);
          auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel<
              SAVE_KV_CACHE,
              INTERLEAVE,
              HEAD_DIM,
              vec_size,
              bdx,
              DType,
              IdType>;
          LAUNCH_KERNEL_RAW(kernel_1);
        }
      });
    });
  });
#undef LAUNCH_KERNEL_RAW

  return cudaSuccess;
}
} // namespace flashinfer
#endif // SGL_POS_ENC_CUH_

View File

@@ -0,0 +1,164 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pos_enc.cuh"
#include "pytorch_extension_utils.h"
using namespace flashinfer;
// PyTorch-facing entry point: applies rotary position embedding to q and k using a precomputed
// cos/sin cache, writing the result to q_rope / k_rope.
//
// Arguments
//   q, k             (nnz, H_Q, D) / (nnz, H_K, D); last dimension must have stride 1.
//   q_rope, k_rope   Outputs; same shape as q / k, last dimension must have stride 1.
//   cos_sin_cache    (max_seq_len, R) float32, contiguous. First half of R is cos, second half sin.
//   pos_ids          int64, contiguous, at least nnz elements (one position per token).
//   interleave       Selects the interleaved vs. non-interleaved rotation variant.
//   enable_pdl       Forwarded to the fused launcher; only accepted when v is provided.
//   cuda_stream      A cudaStream_t handle passed as an integer.
//   v, k_buffer, v_buffer, kv_cache_loc
//                    Optional. If `v` is given, all four must be given and the fused
//                    BatchQKApplyRotaryPosIdsCosSinCacheEnhanced path is taken, which also
//                    receives the KV-cache buffers. Otherwise the original
//                    BatchQKApplyRotaryPosIdsCosSinCache path is used.
//
// Every tensor that is reinterpreted as a raw typed pointer below is validated first, so a wrong
// dtype/layout raises a c10::Error (via TORCH_CHECK) instead of silently reading or writing the
// wrong bytes on the device.
void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    bool enable_pdl,
    int64_t cuda_stream,
    const std::optional<at::Tensor>& v,
    const std::optional<at::Tensor>& k_buffer,
    const std::optional<at::Tensor>& v_buffer,
    const std::optional<at::Tensor>& kv_cache_loc) {
  CHECK_LAST_DIM_CONTIGUOUS(q);
  CHECK_LAST_DIM_CONTIGUOUS(k);

  // The presence of `v` is what switches on the fused "rope + save KV cache" path.
  const bool save_kv_cache = v.has_value();
  if (save_kv_cache) {
    TORCH_CHECK(k_buffer.has_value());
    TORCH_CHECK(v_buffer.has_value());
    TORCH_CHECK(kv_cache_loc.has_value());
    CHECK_LAST_DIM_CONTIGUOUS(v.value());
    CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
    CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
    CHECK_DIM(3, k_buffer.value());      // k_buffer: (nnz, H_K, D)
    CHECK_DIM(3, v_buffer.value());      // v_buffer: (nnz, H_V, D)
    CHECK_DIM(3, v.value());             // v: (nnz, H_V, D)
    CHECK_DIM(1, kv_cache_loc.value());  // kv_cache_loc: (n)
    CHECK_INPUT(kv_cache_loc.value());
    // kv_cache_loc is handed to the kernel as int64_t*.
    TORCH_CHECK(kv_cache_loc->scalar_type() == at::kLong, "kv_cache_loc must be an int64 tensor");
    // v / k_buffer / v_buffer are handed to the kernel as pointers to q's element type.
    TORCH_CHECK(v->scalar_type() == q.scalar_type(), "v must have the same dtype as q");
    TORCH_CHECK(k_buffer->scalar_type() == q.scalar_type(), "k_buffer must have the same dtype as q");
    TORCH_CHECK(v_buffer->scalar_type() == q.scalar_type(), "v_buffer must have the same dtype as q");
    CHECK_EQ(v->device(), q.device());
    CHECK_EQ(k_buffer->device(), q.device());
    CHECK_EQ(v_buffer->device(), q.device());
    CHECK_EQ(kv_cache_loc->device(), q.device());
  }
  size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0;
  size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0;
  size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0;
  size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0;
  size_t v_stride_n = save_kv_cache ? v->stride(0) : 0;
  size_t v_stride_h = save_kv_cache ? v->stride(1) : 0;
  auto kv_cache_loc_ptr = save_kv_cache ? static_cast<int64_t*>(kv_cache_loc->data_ptr()) : nullptr;

  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  // Both are reinterpreted below (float* / int64_t*), so the dtypes must match exactly.
  TORCH_CHECK(cos_sin_cache.scalar_type() == at::kFloat, "cos_sin_cache must be a float32 tensor");
  TORCH_CHECK(pos_ids.scalar_type() == at::kLong, "pos_ids must be an int64 tensor");

  auto device = q.device();
  CHECK_EQ(k.device(), device);
  CHECK_EQ(q_rope.device(), device);
  CHECK_EQ(k_rope.device(), device);
  CHECK_EQ(cos_sin_cache.device(), device);
  CHECK_EQ(pos_ids.device(), device);

  CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
  CHECK_DIM(3, k);  // k: (nnz, H_K, D)
  // The outputs are addressed with the same (token, head, dim) indices as the inputs and are
  // written with vectorised stores along the last dimension.
  CHECK_DIM(3, q_rope);  // q_rope: (nnz, H_Q, D)
  CHECK_DIM(3, k_rope);  // k_rope: (nnz, H_K, D)
  CHECK_LAST_DIM_CONTIGUOUS(q_rope);
  CHECK_LAST_DIM_CONTIGUOUS(k_rope);
  TORCH_CHECK(q_rope.sizes() == q.sizes(), "q_rope must have the same shape as q");
  TORCH_CHECK(k_rope.sizes() == k.sizes(), "k_rope must have the same shape as k");
  // k / q_rope / k_rope are all cast to q's element type inside the dtype dispatch.
  TORCH_CHECK(k.scalar_type() == q.scalar_type(), "k must have the same dtype as q");
  TORCH_CHECK(q_rope.scalar_type() == q.scalar_type(), "q_rope must have the same dtype as q");
  TORCH_CHECK(k_rope.scalar_type() == q.scalar_type(), "k_rope must have the same dtype as q");

  // cos_sin_cache: (max_seq_len, R)
  // First half of R is cos, second half is sin
  CHECK_DIM(2, cos_sin_cache);
  CHECK_EQ(q.size(0), k.size(0));
  CHECK_EQ(q.size(2), k.size(2));
  // The kernel reads pos_ids[idx] for every token idx in [0, nnz).
  TORCH_CHECK(pos_ids.numel() >= q.size(0), "pos_ids must provide one position per token in q");

  unsigned int rotary_dim = cos_sin_cache.size(1);
  unsigned int num_qo_heads = q.size(1);
  unsigned int num_kv_heads = k.size(1);
  unsigned int head_dim = q.size(2);
  unsigned int nnz = q.size(0);
  size_t q_stride_n = q.stride(0);
  size_t q_stride_h = q.stride(1);
  size_t k_stride_n = k.stride(0);
  size_t k_stride_h = k.stride(1);
  size_t q_rope_stride_n = q_rope.stride(0);
  size_t q_rope_stride_h = q_rope.stride(1);
  size_t k_rope_stride_n = k_rope.stride(0);
  size_t k_rope_stride_h = k_rope.stride(1);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
    // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
    // to avoid changing original code path; but this branch is feature-complete and should switch to this later
    if (save_kv_cache) {
      // v / k_buffer / v_buffer are guaranteed present here (validated above).
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          static_cast<c_type*>(v->data_ptr()),
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          static_cast<c_type*>(k_buffer->data_ptr()),
          static_cast<c_type*>(v_buffer->data_ptr()),
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          v_stride_n,
          v_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          k_buffer_stride_n,
          k_buffer_stride_h,
          v_buffer_stride_n,
          v_buffer_stride_h,
          kv_cache_loc_ptr,
          interleave,
          save_kv_cache,
          enable_pdl,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
              std::string(cudaGetErrorString(status)));
    } else {
      // The original launcher takes no enable_pdl argument, so the flag cannot be honoured here.
      TORCH_CHECK(
          !enable_pdl,
          "enable_pdl is only supported when v, k_buffer, v_buffer and kv_cache_loc are provided");
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          interleave,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
    }
    return true;
  });
}