adapt to sglang v0.5.2rc1 on dcu

This commit is contained in:
maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

View File

@@ -0,0 +1,170 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#ifndef USE_ROCM
#include <flashinfer/activation.cuh>
#include "utils.h"
#else
#include "hip/hip_act_and_mul.cuh"
#endif
// Adapted from flashinfer activation
// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44
namespace detail {
// Widen a scalar (half / bfloat16 / float) to float so activations can be
// evaluated in fp32 regardless of the storage dtype.
template <typename T>
__device__ __forceinline__ float to_f32(const T& x) {
#if USE_ROCM
  // ROCm path: castToFloat is provided by hip/hip_act_and_mul.cuh.
  return castToFloat(x);
#else
  return static_cast<float>(x);
#endif
}
// Narrow an fp32 intermediate back to the storage dtype T.
template <typename T>
__device__ __forceinline__ T from_f32(float f32) {
#if USE_ROCM
  // ROCm path: castFromFloat is provided by hip/hip_act_and_mul.cuh.
  return castFromFloat<T>(f32);
#else
  return static_cast<T>(f32);
#endif
}
}  // namespace detail
// SiLU (swish) activation: x * sigmoid(x), evaluated in fp32 for accuracy
// and cast back to T.
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  const float xf = detail::to_f32(x);
  const float denom = 1.0f + expf(-xf);
  return detail::from_f32<T>(xf / denom);
}
// Exact (erf-based) GELU: x * 0.5 * (1 + erf(x / sqrt(2))), evaluated in fp32.
template <typename T>
__device__ __forceinline__ T gelu(const T& x) {
  constexpr float kAlpha = M_SQRT1_2;  // 1/sqrt(2)
  float f32_val = detail::to_f32(x);
  // Use the single-precision erff(): the previous erf() call took the
  // double-precision overload, forcing a float->double->float round trip,
  // which is needlessly slow on GPUs.
  return detail::from_f32<T>(f32_val * (0.5f * (1.0f + erff(f32_val * kAlpha))));
}
// QuickGELU: x * sigmoid(1.702 * x), a fast sigmoid approximation of GELU,
// evaluated in fp32.
template <typename T>
__device__ __forceinline__ T gelu_quick_act(const T& x) {
  const float xf = detail::to_f32(x);
  const float denom = 1.0f + expf(-xf * 1.702f);
  return detail::from_f32<T>(xf / denom);
}
// Tanh-approximated GELU:
//   x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), in fp32.
template <typename T>
__device__ __forceinline__ T gelu_tanh(const T& x) {
  constexpr float kCubicCoeff = 0.044715f;
  constexpr float kSqrt2OverPi = 0.7978845608028654f;
  const float xf = detail::to_f32(x);
  const float poly = xf + kCubicCoeff * xf * xf * xf;
  const float cdf = 0.5f * (1.0f + tanhf(kSqrt2OverPi * poly));
  return detail::from_f32<T>(xf * cdf);
}
// Fused SiLU-and-multiply over the last dimension:
//   out[t, :d] = silu(input[t, :d]) * input[t, d:2d]
// One CUDA block per token; the kernel uses 16-byte vectorized accesses.
void silu_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  // Set the device guard BEFORE querying the current stream: the stream must
  // belong to input's device, not whichever device was current on entry.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // elements per 16-byte vector
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
// Fused tanh-GELU-and-multiply over the last dimension:
//   out[t, :d] = gelu_tanh(input[t, :d]) * input[t, d:2d]
// One CUDA block per token; the kernel uses 16-byte vectorized accesses.
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  // Set the device guard BEFORE querying the current stream: the stream must
  // belong to input's device, not whichever device was current on entry.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // elements per 16-byte vector
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
// Fused erf-GELU-and-multiply over the last dimension:
//   out[t, :d] = gelu(input[t, :d]) * input[t, d:2d]
// One CUDA block per token; the kernel uses 16-byte vectorized accesses.
void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  // Set the device guard BEFORE querying the current stream: the stream must
  // belong to input's device, not whichever device was current on entry.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // elements per 16-byte vector
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
#if USE_ROCM
// QuickGELU applied element-wise (no gating half): out[t, i] = gelu_quick(input[t, i]).
// ROCm-only entry point; one block per token, 16-byte vectorized accesses.
void gelu_quick(at::Tensor& out, const at::Tensor& input) {
  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  // Set the device guard BEFORE querying the current stream: the stream must
  // belong to input's device, not whichever device was current on entry.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // elements per 16-byte vector
    dim3 block(std::min(d / vec_size, 1024U));
    sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
    return true;
  });
}
#endif

View File

@@ -0,0 +1,171 @@
#include "pytorch_extension_utils.h"
// Convert a 16-bit float (half / bfloat16) to e4m3 fp8 storage with
// saturation. Only the half/bfloat16 specializations below are meaningful.
template <typename T>
struct ConvertToFP8 {
  static __device__ __nv_fp8_storage_t convert_to_fp8(T value) {
    // NOTE(review): this fallback silently returns 0 for unsupported types;
    // a static_assert would surface misuse at compile time instead.
    return 0;
  }
};
template <>
struct ConvertToFP8<__nv_bfloat16> {
  static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) {
    // __NV_SATFINITE: out-of-range values saturate to the largest finite e4m3.
    return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
  }
};
template <>
struct ConvertToFP8<half> {
  static __device__ __nv_fp8_storage_t convert_to_fp8(half value) {
    return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
  }
};
// Convert an fp32 value to the 16-bit storage type T (half / bfloat16).
// Only the specializations below are meaningful.
template <typename T>
struct ConvertFromFloat {
  static __device__ T convert_from_float(float value) {
    // NOTE(review): silent-zero fallback for unsupported types, mirroring
    // ConvertToFP8; a static_assert would be safer.
    return 0;
  }
};
template <>
struct ConvertFromFloat<__nv_bfloat16> {
  static __device__ __nv_bfloat16 convert_from_float(float value) {
    return __float2bfloat16(value);
  }
};
template <>
struct ConvertFromFloat<half> {
  static __device__ half convert_from_float(float value) {
    return __float2half(value);
  }
};
// Downcast one token's K/V rows to fp8 (e4m3) and scatter them into the
// destination cache. One block per token: blockIdx.x is the token index and
// the block's threads stride over that token's head*dim elements, scaling by
// 1/k_scale and 1/v_scale, clamping to [min_fp8, max_fp8], and writing to the
// slot (loc[token] * mult + offset).
template <typename T>
__global__ void fused_downcast_kernel(
    const T* cache_k,
    const T* cache_v,
    const float* k_scale,
    const float* v_scale,
    __nv_fp8_storage_t* output_k,
    __nv_fp8_storage_t* output_v,
    const int input_sl,
    const int head,
    const int dim,
    const T max_fp8,
    const T min_fp8,
    const int64_t mult,
    const int64_t offset,
    const int64_t* loc) {
  int token_idx = blockIdx.x;
  int thread_idx = threadIdx.x;
  int total_threads = blockDim.x;
  // Scalar scales are broadcast from single-element device tensors.
  T k_scale_val = ConvertFromFloat<T>::convert_from_float(k_scale[0]);
  T v_scale_val = ConvertFromFloat<T>::convert_from_float(v_scale[0]);
  T k_scale_inv = static_cast<T>(1.f) / k_scale_val;
  T v_scale_inv = static_cast<T>(1.f) / v_scale_val;
  auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };
  if (token_idx < input_sl) {
    // Use 64-bit indices throughout: loc[] holds int64 slot indices (the
    // previous `int out_seq_idx` truncated them), and the destination offset
    // (out_seq_idx * mult + offset) * head * dim can exceed INT_MAX for large
    // caches, which silently overflowed with 32-bit arithmetic.
    int64_t out_seq_idx = loc[token_idx];
    for (int i = thread_idx; i < head * dim; i += total_threads) {
      int64_t in_idx = static_cast<int64_t>(token_idx) * head * dim + i;
      int64_t out_idx = (out_seq_idx * mult + offset) * head * dim + i;
      T k_val = cache_k[in_idx] * k_scale_inv;
      k_val = clamp(k_val);
      output_k[out_idx] = ConvertToFP8<T>::convert_to_fp8(k_val);
      T v_val = cache_v[in_idx] * v_scale_inv;
      v_val = clamp(v_val);
      output_v[out_idx] = ConvertToFP8<T>::convert_to_fp8(v_val);
    }
  }
}
// Launch fused_downcast_kernel<T> for contiguous (nnz, H, D) key/value
// tensors on `stream`, then surface any launch error through TORCH_CHECK.
template <typename T>
void downcast_fp8_impl(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset,
    cudaStream_t stream) {
  CHECK_INPUT(k);
  CHECK_INPUT(v);
  CHECK_INPUT(k_out);
  CHECK_INPUT(v_out);
  CHECK_INPUT(k_scale);
  CHECK_INPUT(v_scale);
  CHECK_INPUT(loc);
  int64_t input_sl = k.size(0);
  int64_t head = k.size(1);
  int64_t dim = k.size(2);
  // One block per token: the kernel indexes tokens by blockIdx.x and its
  // threads stride over all head*dim elements. The previous grid of
  // input_sl * head launched (head-1)/head idle blocks that failed the
  // kernel's `token_idx < input_sl` guard and did no work.
  dim3 grid(input_sl);
  int vec_size = 8;
  // Clamp to [1, 1024] threads; without the lower bound, dim < vec_size
  // produced a zero-sized block and an invalid launch.
  dim3 block(std::max(std::min(int(dim) / vec_size, 1024), 1));
  const T max_fp8 = static_cast<T>(448.0f);   // e4m3 finite max
  const T min_fp8 = static_cast<T>(-448.0f);  // e4m3 finite min
  fused_downcast_kernel<T><<<grid, block, 0, stream>>>(
      static_cast<const T*>(k.data_ptr()),
      static_cast<const T*>(v.data_ptr()),
      static_cast<const float*>(k_scale.data_ptr()),
      static_cast<const float*>(v_scale.data_ptr()),
      static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()),
      static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()),
      input_sl,
      head,
      dim,
      max_fp8,
      min_fp8,
      mult,
      offset,
      static_cast<const int64_t*>(loc.data_ptr()));
  cudaError_t status = cudaGetLastError();
  TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status)));
}
// Public entry point: validate inputs, dispatch on the key dtype
// (bfloat16 / float16), and run the fused fp8 downcast on the
// caller-provided CUDA stream.
void downcast_fp8(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset,
    int64_t cuda_stream) {
  CHECK_INPUT(k);
  CHECK_INPUT(v);
  CHECK_INPUT(k_out);
  CHECK_INPUT(v_out);
  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  const auto dtype = k.scalar_type();
  if (dtype == at::ScalarType::BFloat16) {
    downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
  } else if (dtype == at::ScalarType::Half) {
    downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
  } else {
    TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16.");
  }
}

View File

@@ -0,0 +1,59 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/cuda/CUDAContext.h>
#include <flashinfer/norm.cuh>
#include "utils.h"
using namespace flashinfer;
// In-place fused residual-add + RMSNorm via flashinfer::norm::FusedAddRMSNorm.
// input:    (batch_size, hidden_size) — overwritten with the normalized output.
// residual: (batch_size, hidden_size) — per upstream flashinfer, updated with
//           input + residual (in-place semantics; confirmed by passing raw
//           mutable pointers below).
// weight:   (hidden_size) RMSNorm gain.
// enable_pdl toggles programmatic dependent launch inside the kernel.
void sgl_fused_add_rmsnorm(
    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
  CHECK_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  auto device = input.device();
  CHECK_EQ(residual.device(), device);
  CHECK_EQ(weight.device(), device);
  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
  CHECK_DIM(1, weight);    // weight: (hidden_size)
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);
  // NOTE(review): the stream is fetched for the *current* device; callers are
  // assumed to have already selected input's device (no device guard here).
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::FusedAddRMSNorm(
        static_cast<c_type*>(input.data_ptr()),
        static_cast<c_type*>(residual.data_ptr()),
        static_cast<c_type*>(weight.data_ptr()),
        batch_size,
        hidden_size,
        input.stride(0),     // row strides allow non-contiguous batch dims
        residual.stride(0),
        eps,
        enable_pdl,
        torch_current_stream);
    TORCH_CHECK(
        status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}

View File

@@ -0,0 +1,467 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SGL_POS_ENC_CUH_
#define SGL_POS_ENC_CUH_
#include <flashinfer/pos_enc.cuh> // upstream
namespace flashinfer {
// Helpers used by the rope kernels to also persist the rotated K and the raw V
// into the paged KV cache in the same pass, avoiding a second scatter kernel.
namespace kv_buffer_saver {
// Load this thread's vec_size-wide slice of V for token `idx` / head
// `kv_head_idx` into v_vec (as fp32), and record the token's destination slot
// from kv_cache_loc. Called before the rope math so the V load overlaps it.
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void prepare(
    vec_t<float, vec_size>& v_vec,
    IdType& kv_cache_offset,
    DType* v,
    IdType* kv_cache_loc,
    uint32_t idx,
    uint32_t tx,
    uint32_t kv_head_idx,
    size_t v_stride_n,
    size_t v_stride_h) {
  kv_cache_offset = kv_cache_loc[idx];
  DType* v_ptr = v + get_elem_offset_impl(idx, kv_head_idx, 0, v_stride_n, v_stride_h);
  v_vec.cast_load(v_ptr + tx * vec_size);
}
// Store the rotated K (k_vec) and the previously-loaded V (v_vec) into the
// cache buffers at the slot recorded by prepare(), narrowing back to DType.
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void save(
    IdType& kv_cache_offset,
    vec_t<float, vec_size>& k_vec,
    vec_t<float, vec_size>& v_vec,
    DType* k_buffer,
    DType* v_buffer,
    uint32_t idx,
    uint32_t tx,
    uint32_t kv_head_idx,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h) {
  DType* k_buffer_ptr =
      k_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, k_buffer_stride_n, k_buffer_stride_h);
  DType* v_buffer_ptr =
      v_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, v_buffer_stride_n, v_buffer_stride_h);
  k_vec.cast_store(k_buffer_ptr + tx * vec_size);
  v_vec.cast_store(v_buffer_ptr + tx * vec_size);
}
}  // namespace kv_buffer_saver
// Head-parallel rope kernel: gridDim.y spans num_qo_heads + num_kv_heads so
// each (blockIdx.x, blockIdx.y) pair handles one head of a group of tokens.
// Chosen by the host launcher when the token-only grid would underfill the GPU.
// Layout: block = (bdx, bdy) where bdx threads cooperate on one head_dim and
// bdy tokens share a block; optionally scatters K (rotated) and V (raw) into
// the KV cache via kv_buffer_saver when save_kv_cache is set.
template <
    bool save_kv_cache,
    bool interleave,
    uint32_t head_dim,
    uint32_t vec_size,
    uint32_t bdx,
    typename DType,
    typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* __restrict__ cos_sin_cache,
    IdType* __restrict__ pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,  // unused here; kept for signature parity with the non-head-parallel kernel
    uint32_t rotary_dim,
    size_t q_stride_n,
    size_t q_stride_h,
    size_t k_stride_n,
    size_t k_stride_h,
    size_t v_stride_n,
    size_t v_stride_h,
    size_t q_rope_stride_n,
    size_t q_rope_stride_h,
    size_t k_rope_stride_n,
    size_t k_rope_stride_h,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h,
    IdType* __restrict__ kv_cache_loc) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  uint32_t by = blockIdx.y;  // head index across [qo heads | kv heads]
  const uint32_t bdy = blockDim.y;
// Programmatic dependent launch (SM90+): wait for the upstream kernel's signal.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif
  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;  // token index
    const IdType pos = pos_ids[idx];
    const int half_rotary_dim = rotary_dim / 2;
    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }
    if (by < num_qo_heads) {
      // Query-head block: rotate q into q_rope only.
      uint32_t qo_head_idx = by;
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    } else {
      // KV-head block: rotate k into k_rope, and optionally persist k/v to the cache.
      uint32_t kv_head_idx = by - num_qo_heads;
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
      vec_t<float, vec_size> v_vec;
      IdType kv_cache_offset;
      if constexpr (save_kv_cache) {
        // Load V early so it overlaps the rope math on K.
        kv_buffer_saver::prepare<DType, IdType, vec_size>(
            v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
      }
      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);
      if constexpr (save_kv_cache) {
        kv_buffer_saver::save<DType, IdType, vec_size>(
            kv_cache_offset,
            k_vec,
            v_vec,
            k_buffer,
            v_buffer,
            idx,
            tx,
            kv_head_idx,
            k_buffer_stride_n,
            k_buffer_stride_h,
            v_buffer_stride_n,
            v_buffer_stride_h);
      }
    }
  }
// Programmatic dependent launch (SM90+): allow dependent kernels to start.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Token-parallel rope kernel: each (block, ty) slot owns one token and loops
// over ALL query heads then all kv heads serially. Preferred by the host
// launcher when the token count alone is enough to saturate the device.
// Optionally scatters K (rotated) and V (raw) into the KV cache via
// kv_buffer_saver when save_kv_cache is set.
template <
    bool save_kv_cache,
    bool interleave,
    uint32_t head_dim,
    uint32_t vec_size,
    uint32_t bdx,
    typename DType,
    typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* __restrict__ cos_sin_cache,
    IdType* __restrict__ pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,
    uint32_t rotary_dim,
    size_t q_stride_n,
    size_t q_stride_h,
    size_t k_stride_n,
    size_t k_stride_h,
    size_t v_stride_n,
    size_t v_stride_h,
    size_t q_rope_stride_n,
    size_t q_rope_stride_h,
    size_t k_rope_stride_n,
    size_t k_rope_stride_h,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h,
    IdType* __restrict__ kv_cache_loc) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;
// Programmatic dependent launch (SM90+): wait for the upstream kernel's signal.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif
  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;  // token index
    const IdType pos = pos_ids[idx];
    const int half_rotary_dim = rotary_dim / 2;
    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }
    // not to unroll the loop, because num head might be large and might lead to worse performance
#pragma unroll 1
    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    }
#pragma unroll 1
    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
      vec_t<float, vec_size> v_vec;
      IdType kv_cache_offset;
      if constexpr (save_kv_cache) {
        // Load V early so it overlaps the rope math on K.
        kv_buffer_saver::prepare<DType, IdType, vec_size>(
            v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
      }
      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);
      if constexpr (save_kv_cache) {
        kv_buffer_saver::save<DType, IdType, vec_size>(
            kv_cache_offset,
            k_vec,
            v_vec,
            k_buffer,
            v_buffer,
            idx,
            tx,
            kv_head_idx,
            k_buffer_stride_n,
            k_buffer_stride_h,
            v_buffer_stride_n,
            v_buffer_stride_h);
      }
    }
  }
// Programmatic dependent launch (SM90+): allow dependent kernels to start.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Turn the runtime flag `save_kv_cache` into a compile-time constant
// SAVE_KV_CACHE visible to __VA_ARGS__, so the kernels can specialize on it
// via `if constexpr`. Mirrors flashinfer's DISPATCH_INTERLEAVE pattern.
#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
  if (save_kv_cache) {                                            \
    const bool SAVE_KV_CACHE = true;                              \
    __VA_ARGS__                                                   \
  } else {                                                        \
    const bool SAVE_KV_CACHE = false;                             \
    __VA_ARGS__                                                   \
  }
// Host launcher for the enhanced rope kernels. Dispatches on save_kv_cache /
// interleave / head_dim, then picks between:
//   - kernel_0 (token-parallel): when the token grid alone can occupy the GPU;
//   - kernel_1 (head-parallel):  otherwise, spreading heads over gridDim.y.
// Both kernels are launched through cudaLaunchKernelEx so the programmatic
// stream serialization (PDL) attribute can be toggled via enable_pdl.
// Returns cudaSuccess or the first failing CUDA call's error code (via
// FLASHINFER_CUDA_CALL early-return).
template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* cos_sin_cache,
    IdType* pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,
    uint32_t rotary_dim,
    uint32_t head_dim,
    size_t q_stride_n,
    size_t q_stride_h,
    size_t k_stride_n,
    size_t k_stride_h,
    size_t v_stride_n,
    size_t v_stride_h,
    size_t q_rope_stride_n,
    size_t q_rope_stride_h,
    size_t k_rope_stride_n,
    size_t k_rope_stride_h,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h,
    IdType* kv_cache_loc,
    bool interleave,
    bool save_kv_cache,
    bool enable_pdl,
    cudaStream_t stream = nullptr) {
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
// Launch `kernel_name` with the current nblks/nthrs via cudaLaunchKernelEx,
// attaching the PDL attribute controlled by enable_pdl. (No comments inside:
// the macro body relies on line continuations.)
#define LAUNCH_KERNEL_RAW(kernel_name)                                   \
  do {                                                                   \
    cudaLaunchConfig_t config = {};                                      \
    config.gridDim = nblks;                                              \
    config.blockDim = nthrs;                                             \
    config.dynamicSmemBytes = 0;                                         \
    config.stream = stream;                                              \
    cudaLaunchAttribute attrs[1] = {};                                   \
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;    \
    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;    \
    config.numAttrs = 1;                                                 \
    config.attrs = attrs;                                                \
                                                                         \
    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(                             \
        &config,                                                         \
        kernel_name,                                                     \
        q,                                                               \
        k,                                                               \
        v,                                                               \
        q_rope,                                                          \
        k_rope,                                                          \
        k_buffer,                                                        \
        v_buffer,                                                        \
        cos_sin_cache,                                                   \
        pos_ids,                                                         \
        nnz,                                                             \
        num_qo_heads,                                                    \
        num_kv_heads,                                                    \
        rotary_dim,                                                      \
        q_stride_n,                                                      \
        q_stride_h,                                                      \
        k_stride_n,                                                      \
        k_stride_h,                                                      \
        v_stride_n,                                                      \
        v_stride_h,                                                      \
        q_rope_stride_n,                                                 \
        q_rope_stride_h,                                                 \
        k_rope_stride_n,                                                 \
        k_rope_stride_h,                                                 \
        k_buffer_stride_n,                                               \
        k_buffer_stride_h,                                               \
        v_buffer_stride_n,                                               \
        v_buffer_stride_h,                                               \
        kv_cache_loc));                                                  \
  } while (0)
  DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
    DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
      DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
        // operate on 16 Bytes at a time
        constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
        // how many threads needed per head_dim
        constexpr uint32_t bdx = HEAD_DIM / vec_size;
        // how many threads needed per block
        uint32_t num_threads = std::max(128U, bdx);
        // how many tokens can we process in a block
        uint32_t bdy = num_threads / bdx;
        // how many blocks needed to process all tokens
        uint32_t nblks_x = (nnz + bdy - 1) / bdy;
        auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
            SAVE_KV_CACHE,
            INTERLEAVE,
            HEAD_DIM,
            vec_size,
            bdx,
            DType,
            IdType>;
        int num_blocks_per_sm_0 = 0;
        FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
        uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;
        // Token grid alone can fill the device -> token-parallel kernel;
        // otherwise add a head dimension to the grid for more parallelism.
        if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
          dim3 nblks(nblks_x);
          dim3 nthrs(bdx, bdy);
          LAUNCH_KERNEL_RAW(kernel_0);
        } else {
          dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
          dim3 nthrs(bdx, bdy);
          auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel<
              SAVE_KV_CACHE,
              INTERLEAVE,
              HEAD_DIM,
              vec_size,
              bdx,
              DType,
              IdType>;
          LAUNCH_KERNEL_RAW(kernel_1);
        }
      });
    });
  });
#undef LAUNCH_KERNEL_RAW
  return cudaSuccess;
}
} // namespace flashinfer
#endif // SGL_POS_ENC_CUH_

View File

@@ -0,0 +1,164 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pos_enc.cuh"
#include "pytorch_extension_utils.h"
using namespace flashinfer;
// PyTorch binding: apply rotary position embedding to q/k using a fused
// cos/sin cache, writing into q_rope/k_rope. When v (and friends) are
// provided, the enhanced kernel additionally scatters the rotated K and the
// raw V into the paged KV cache at kv_cache_loc in the same pass.
// q:             (nnz, H_Q, D)
// k:             (nnz, H_K, D)
// cos_sin_cache: (max_seq_len, rotary_dim) — first half cos, second half sin
// pos_ids:       (nnz,) — assumed int64; data_ptr is cast unchecked (TODO confirm callers)
void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    bool enable_pdl,
    int64_t cuda_stream,
    const std::optional<at::Tensor>& v,
    const std::optional<at::Tensor>& k_buffer,
    const std::optional<at::Tensor>& v_buffer,
    const std::optional<at::Tensor>& kv_cache_loc) {
  CHECK_LAST_DIM_CONTIGUOUS(q);
  CHECK_LAST_DIM_CONTIGUOUS(k);
  // Presence of v selects the save-KV-cache path; the remaining optionals
  // must then all be provided together.
  const bool save_kv_cache = v.has_value();
  if (save_kv_cache) {
    TORCH_CHECK(v.has_value());
    TORCH_CHECK(k_buffer.has_value());
    TORCH_CHECK(v_buffer.has_value());
    TORCH_CHECK(kv_cache_loc.has_value());
    CHECK_LAST_DIM_CONTIGUOUS(v.value());
    CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
    CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
    CHECK_DIM(3, k_buffer.value());     // k_buffer: (nnz, H_K, D)
    CHECK_DIM(3, v_buffer.value());     // v_buffer: (nnz, H_V, D)
    CHECK_DIM(3, v.value());            // v: (nnz, H_V, D)
    CHECK_DIM(1, kv_cache_loc.value()); // kv_cache_loc: (n)
    CHECK_INPUT(kv_cache_loc.value());
  }
  // Strides default to 0 when the save path is off; the kernels never touch
  // the corresponding (null) pointers in that case.
  size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0;
  size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0;
  size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0;
  size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0;
  size_t v_stride_n = save_kv_cache ? v->stride(0) : 0;
  size_t v_stride_h = save_kv_cache ? v->stride(1) : 0;
  // NOTE(review): assumes kv_cache_loc is an int64 tensor — cast is unchecked.
  auto kv_cache_loc_ptr = save_kv_cache ? static_cast<int64_t*>(kv_cache_loc->data_ptr()) : nullptr;
  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  auto device = q.device();
  CHECK_EQ(k.device(), device);
  CHECK_EQ(cos_sin_cache.device(), device);
  CHECK_EQ(pos_ids.device(), device);
  CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
  CHECK_DIM(3, k);  // k: (nnz, H_K, D)
  // cos_sin_cache: (max_seq_len, R)
  // First half of R is cos, second half is sin
  CHECK_DIM(2, cos_sin_cache);
  CHECK_EQ(q.size(0), k.size(0));
  CHECK_EQ(q.size(2), k.size(2));
  unsigned int rotary_dim = cos_sin_cache.size(1);
  unsigned int num_qo_heads = q.size(1);
  unsigned int num_kv_heads = k.size(1);
  unsigned int head_dim = q.size(2);
  unsigned int nnz = q.size(0);
  size_t q_stride_n = q.stride(0);
  size_t q_stride_h = q.stride(1);
  size_t k_stride_n = k.stride(0);
  size_t k_stride_h = k.stride(1);
  size_t q_rope_stride_n = q_rope.stride(0);
  size_t q_rope_stride_h = q_rope.stride(1);
  size_t k_rope_stride_n = k_rope.stride(0);
  size_t k_rope_stride_h = k_rope.stride(1);
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
    // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
    // to avoid changing original code path; but this branch is feature-complete and should switch to this later
    if (save_kv_cache) {
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          save_kv_cache ? static_cast<c_type*>(v->data_ptr()) : nullptr,
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          save_kv_cache ? static_cast<c_type*>(k_buffer->data_ptr()) : nullptr,
          save_kv_cache ? static_cast<c_type*>(v_buffer->data_ptr()) : nullptr,
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          v_stride_n,
          v_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          k_buffer_stride_n,
          k_buffer_stride_h,
          v_buffer_stride_n,
          v_buffer_stride_h,
          kv_cache_loc_ptr,
          interleave,
          save_kv_cache,
          enable_pdl,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
              std::string(cudaGetErrorString(status)));
    } else {
      // PDL is only wired into the enhanced (save-KV-cache) launcher.
      TORCH_CHECK(!enable_pdl);
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          interleave,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
    }
    return true;
  });
}