adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
170
sgl-kernel/csrc/elementwise/activation.cu
Normal file
170
sgl-kernel/csrc/elementwise/activation.cu
Normal file
@@ -0,0 +1,170 @@
|
||||
/*
|
||||
* Copyright (c) 2024 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
#include <flashinfer/activation.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#else
|
||||
#include "hip/hip_act_and_mul.cuh"
|
||||
#endif
|
||||
|
||||
// Adapted from flashinfer activation
|
||||
// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44
|
||||
|
||||
// Internal helpers that bridge the CUDA and ROCm float-conversion APIs so the
// activation functors below can be written once and compiled for both backends.
namespace detail {

// Widen a (possibly fp16/bf16) value to float so activations are computed in fp32.
template <typename T>
__device__ __forceinline__ float to_f32(const T& x) {
#if USE_ROCM
  // castToFloat comes from the sgl_hip headers -- presumably the HIP analogue of
  // static_cast<float> for fp16/bf16; confirm against hip/hip_act_and_mul.cuh.
  return castToFloat(x);
#else
  return static_cast<float>(x);
#endif
}

// Narrow an fp32 result back to the storage type T after the activation is applied.
template <typename T>
__device__ __forceinline__ T from_f32(float f32) {
#if USE_ROCM
  return castFromFloat<T>(f32);
#else
  return static_cast<T>(f32);
#endif
}

}  // namespace detail
|
||||
|
||||
// SiLU (sigmoid-weighted linear unit): x * sigmoid(x), evaluated in fp32 and
// narrowed back to the storage type T.
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  const float xf = detail::to_f32(x);
  return detail::from_f32<T>(xf / (1.0f + expf(-xf)));
}
|
||||
|
||||
// Exact ("erf") GELU: x * Phi(x), where Phi is the standard-normal CDF,
// evaluated in fp32 regardless of the storage type T.
template <typename T>
__device__ __forceinline__ T gelu(const T& x) {
  constexpr float kAlpha = M_SQRT1_2;  // 1/sqrt(2)
  float f32_val = detail::to_f32(x);
  // Use the single-precision erff: an unqualified erf on a float argument can
  // resolve to the double overload in device code, silently promoting the whole
  // expression to double precision (much slower on most GPUs).
  return detail::from_f32<T>(f32_val * (0.5f * (1.0f + erff(f32_val * kAlpha))));
}
|
||||
|
||||
// QuickGELU: x * sigmoid(1.702 * x) -- a cheap, tanh-free GELU approximation.
template <typename T>
__device__ __forceinline__ T gelu_quick_act(const T& x) {
  const float xf = detail::to_f32(x);
  return detail::from_f32<T>(xf / (1.0f + expf(-xf * 1.702f)));
}
|
||||
|
||||
// Tanh-approximated GELU:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
template <typename T>
__device__ __forceinline__ T gelu_tanh(const T& x) {
  constexpr float kAlpha = 0.044715f;           // cubic-term coefficient
  constexpr float kBeta = 0.7978845608028654f;  // sqrt(2/pi)
  const float xf = detail::to_f32(x);
  const float inner = kBeta * (xf + kAlpha * xf * xf * xf);
  const float cdf = 0.5f * (1.0f + tanhf(inner));
  return detail::from_f32<T>(xf * cdf);
}
|
||||
|
||||
// Gated SiLU: out[t, :d] = silu(input[t, :d]) * input[t, d:2d] for each token t.
// `input` has last dimension 2*d, `out` has last dimension d; one block per token.
void silu_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Make input's device current BEFORE querying the stream; otherwise
  // getCurrentCUDAStream() returns the stream of the caller's current device,
  // which may not be the device the tensors live on.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // 16-byte vectorized accesses per thread
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
|
||||
|
||||
// Gated tanh-GELU: out[t, :d] = gelu_tanh(input[t, :d]) * input[t, d:2d].
// `input` has last dimension 2*d, `out` has last dimension d; one block per token.
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Device guard first, stream second: the stream must belong to input's
  // device, not to whatever device happened to be current at the call site.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // 16-byte vectorized accesses per thread
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
|
||||
|
||||
// Gated exact GELU: out[t, :d] = gelu(input[t, :d]) * input[t, d:2d].
// `input` has last dimension 2*d, `out` has last dimension d; one block per token.
void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Device guard first, stream second: the stream must belong to input's
  // device, not to whatever device happened to be current at the call site.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // 16-byte vectorized accesses per thread
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
|
||||
|
||||
#if USE_ROCM
// Elementwise QuickGELU (no gating): out[t, i] = gelu_quick_act(input[t, i]).
// ROCm-only entry point; `d` is the full last dimension here, not half of it.
void gelu_quick(at::Tensor& out, const at::Tensor& input) {
  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Device guard first, stream second: the stream must belong to input's
  // device, not to whatever device happened to be current at the call site.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // 16-byte vectorized accesses per thread
    dim3 block(std::min(d / vec_size, 1024U));
    sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}
#endif
|
||||
171
sgl-kernel/csrc/elementwise/cast.cu
Normal file
171
sgl-kernel/csrc/elementwise/cast.cu
Normal file
@@ -0,0 +1,171 @@
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
// Converts a 16-bit float value to fp8 E4M3 storage with saturating rounding.
// NOTE(review): the unspecialized template silently returns 0 for unsupported
// types instead of failing to compile -- only the __nv_bfloat16 and half
// specializations below produce meaningful results.
template <typename T>
struct ConvertToFP8 {
  static __device__ __nv_fp8_storage_t convert_to_fp8(T value) {
    return 0;
  }
};

// bf16 -> fp8 E4M3, saturating to the finite range.
template <>
struct ConvertToFP8<__nv_bfloat16> {
  static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) {
    return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
  }
};

// fp16 -> fp8 E4M3, saturating to the finite range.
template <>
struct ConvertToFP8<half> {
  static __device__ __nv_fp8_storage_t convert_to_fp8(half value) {
    return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
  }
};
|
||||
|
||||
// Converts an fp32 value to the 16-bit storage type T.
// NOTE(review): like ConvertToFP8, the unspecialized template silently
// returns 0; only the __nv_bfloat16 and half specializations are meaningful.
template <typename T>
struct ConvertFromFloat {
  static __device__ T convert_from_float(float value) {
    return 0;
  }
};

// fp32 -> bf16 (round-to-nearest per __float2bfloat16).
template <>
struct ConvertFromFloat<__nv_bfloat16> {
  static __device__ __nv_bfloat16 convert_from_float(float value) {
    return __float2bfloat16(value);
  }
};

// fp32 -> fp16 (round-to-nearest per __float2half).
template <>
struct ConvertFromFloat<half> {
  static __device__ half convert_from_float(float value) {
    return __float2half(value);
  }
};
|
||||
|
||||
// Fused K/V downcast to fp8: one block per input token. Each thread strides
// over the token's head*dim elements, scales K (V) by 1/k_scale (1/v_scale),
// clamps to the fp8 E4M3 finite range [min_fp8, max_fp8], converts to fp8 and
// scatters into the cache slot (loc[token] * mult + offset).
//
// k_scale / v_scale are single-element device buffers read by every thread.
template <typename T>
__global__ void fused_downcast_kernel(
    const T* cache_k,
    const T* cache_v,
    const float* k_scale,
    const float* v_scale,
    __nv_fp8_storage_t* output_k,
    __nv_fp8_storage_t* output_v,
    const int input_sl,
    const int head,
    const int dim,
    const T max_fp8,
    const T min_fp8,
    const int64_t mult,
    const int64_t offset,
    const int64_t* loc) {
  int token_idx = blockIdx.x;
  int thread_idx = threadIdx.x;
  int total_threads = blockDim.x;

  T k_scale_val = ConvertFromFloat<T>::convert_from_float(k_scale[0]);
  T v_scale_val = ConvertFromFloat<T>::convert_from_float(v_scale[0]);

  T k_scale_inv = static_cast<T>(1.f) / k_scale_val;
  T v_scale_inv = static_cast<T>(1.f) / v_scale_val;

  auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };

  if (token_idx < input_sl) {
    int64_t out_seq_idx = loc[token_idx];

    for (int i = thread_idx; i < head * dim; i += total_threads) {
      // 64-bit indexing: the output offset mixes int64 `mult`/`offset` with
      // the cache position, and large caches overflow 32-bit arithmetic.
      // (The original computed these in `int`, truncating the int64 product.)
      int64_t in_idx = static_cast<int64_t>(token_idx) * head * dim + i;
      int64_t out_idx = (out_seq_idx * mult + offset) * head * dim + i;

      T k_val = cache_k[in_idx] * k_scale_inv;
      k_val = clamp(k_val);
      output_k[out_idx] = ConvertToFP8<T>::convert_to_fp8(k_val);

      T v_val = cache_v[in_idx] * v_scale_inv;
      v_val = clamp(v_val);
      output_v[out_idx] = ConvertToFP8<T>::convert_to_fp8(v_val);
    }
  }
}
|
||||
|
||||
// Host-side launcher for fused_downcast_kernel, specialized on dtype T.
// Validates the tensors, derives (seq_len, head, dim) from k's shape, launches
// the kernel on `stream`, and raises on launch failure.
template <typename T>
void downcast_fp8_impl(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset,
    cudaStream_t stream) {
  CHECK_INPUT(k);
  CHECK_INPUT(v);
  CHECK_INPUT(k_out);
  CHECK_INPUT(v_out);
  CHECK_INPUT(k_scale);
  CHECK_INPUT(v_scale);
  CHECK_INPUT(loc);

  int64_t input_sl = k.size(0);
  int64_t head = k.size(1);
  int64_t dim = k.size(2);

  // One block per token: the kernel loops over a token's full head*dim span
  // inside a single block and guards on `token_idx < input_sl`, so launching
  // input_sl blocks covers all data. (The original launched input_sl * head
  // blocks; the extra (head-1)*input_sl blocks failed the guard and did
  // nothing, wasting launch/scheduling work.)
  dim3 grid(input_sl);
  int vec_size = 8;
  // Clamp to at least one thread so a tiny `dim` (< vec_size) still launches.
  dim3 block(std::max(1, std::min(int(dim) / vec_size, 1024)));

  // 448 is the largest finite magnitude representable in fp8 E4M3.
  const T max_fp8 = static_cast<T>(448.0f);
  const T min_fp8 = static_cast<T>(-448.0f);

  fused_downcast_kernel<T><<<grid, block, 0, stream>>>(
      static_cast<const T*>(k.data_ptr()),
      static_cast<const T*>(v.data_ptr()),
      static_cast<const float*>(k_scale.data_ptr()),
      static_cast<const float*>(v_scale.data_ptr()),
      static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()),
      static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()),
      input_sl,
      head,
      dim,
      max_fp8,
      min_fp8,
      mult,
      offset,
      static_cast<const int64_t*>(loc.data_ptr()));

  // Surface launch-configuration errors immediately (kernel launches are async
  // and do not return a status themselves).
  cudaError_t status = cudaGetLastError();
  TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status)));
}
|
||||
|
||||
// Public entry point: downcasts K/V tensors to fp8 (E4M3) cache storage,
// dispatching on k's dtype. Supports bfloat16 and float16; anything else raises.
void downcast_fp8(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset,
    int64_t cuda_stream) {
  CHECK_INPUT(k);
  CHECK_INPUT(v);
  CHECK_INPUT(k_out);
  CHECK_INPUT(v_out);

  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  const auto dtype = k.scalar_type();
  if (dtype == at::ScalarType::BFloat16) {
    downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
  } else if (dtype == at::ScalarType::Half) {
    downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
  } else {
    TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16.");
  }
}
|
||||
59
sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
Normal file
59
sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
Normal file
@@ -0,0 +1,59 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <flashinfer/norm.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// Fused residual-add + RMSNorm, delegating to flashinfer's FusedAddRMSNorm.
// Updates `input` and `residual` in place (flashinfer's fused kernel writes
// both operands -- confirm exact in/out semantics against flashinfer/norm.cuh).
//
// input:    (batch_size, hidden_size), fp16/bf16/fp32
// residual: (batch_size, hidden_size), same dtype/device as input
// weight:   (hidden_size,) RMSNorm scale
// eps:      numerical-stability epsilon added inside the norm
// enable_pdl: forwarded to flashinfer (programmatic dependent launch toggle)
void sgl_fused_add_rmsnorm(
    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
  CHECK_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  auto device = input.device();
  CHECK_EQ(residual.device(), device);
  CHECK_EQ(weight.device(), device);
  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
  CHECK_DIM(1, weight);    // weight: (hidden_size)
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);

  // NOTE(review): the stream is queried before any device guard is set;
  // callers must already be on input's device -- verify at the call sites.
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::FusedAddRMSNorm(
        static_cast<c_type*>(input.data_ptr()),
        static_cast<c_type*>(residual.data_ptr()),
        static_cast<c_type*>(weight.data_ptr()),
        batch_size,
        hidden_size,
        input.stride(0),
        residual.stride(0),
        eps,
        enable_pdl,
        torch_current_stream);
    TORCH_CHECK(
        status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
|
||||
467
sgl-kernel/csrc/elementwise/pos_enc.cuh
Normal file
467
sgl-kernel/csrc/elementwise/pos_enc.cuh
Normal file
@@ -0,0 +1,467 @@
|
||||
/*
|
||||
* Copyright (c) 2023 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef SGL_POS_ENC_CUH_
|
||||
#define SGL_POS_ENC_CUH_
|
||||
|
||||
#include <flashinfer/pos_enc.cuh> // upstream
|
||||
|
||||
namespace flashinfer {
|
||||
|
||||
namespace kv_buffer_saver {
|
||||
|
||||
// Pre-load step for saving to the paged KV cache: reads this token's cache slot
// from kv_cache_loc and loads this thread's vec_size-wide slice of V into
// registers (as fp32) so it can be stored alongside the rotated K later.
//
// idx:          token index in the flattened batch
// tx:           lane within the head (each lane owns vec_size elements)
// kv_head_idx:  which KV head's V slice to load
// Outputs v_vec (V slice) and kv_cache_offset (destination row in the cache).
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void prepare(
    vec_t<float, vec_size>& v_vec,
    IdType& kv_cache_offset,
    DType* v,
    IdType* kv_cache_loc,
    uint32_t idx,
    uint32_t tx,
    uint32_t kv_head_idx,
    size_t v_stride_n,
    size_t v_stride_h) {
  kv_cache_offset = kv_cache_loc[idx];

  DType* v_ptr = v + get_elem_offset_impl(idx, kv_head_idx, 0, v_stride_n, v_stride_h);
  // cast_load widens DType -> fp32 on load (see flashinfer's vec_t).
  v_vec.cast_load(v_ptr + tx * vec_size);
}
|
||||
|
||||
// Store step for saving to the paged KV cache: writes the rotated K slice and
// the V slice (loaded earlier by prepare()) into the cache row given by
// kv_cache_offset, narrowing fp32 -> DType on store.
//
// `idx` is unused here (the destination row is kv_cache_offset); it is kept in
// the signature to mirror prepare().
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void save(
    IdType& kv_cache_offset,
    vec_t<float, vec_size>& k_vec,
    vec_t<float, vec_size>& v_vec,
    DType* k_buffer,
    DType* v_buffer,
    uint32_t idx,
    uint32_t tx,
    uint32_t kv_head_idx,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h) {
  DType* k_buffer_ptr =
      k_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, k_buffer_stride_n, k_buffer_stride_h);
  DType* v_buffer_ptr =
      v_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, v_buffer_stride_n, v_buffer_stride_h);
  k_vec.cast_store(k_buffer_ptr + tx * vec_size);
  v_vec.cast_store(v_buffer_ptr + tx * vec_size);
}
|
||||
|
||||
} // namespace kv_buffer_saver
|
||||
|
||||
// Head-parallel RoPE kernel: applies rotary position embeddings to Q and K
// using a precomputed cos/sin cache, optionally scattering rotated K and raw V
// into the paged KV cache (save_kv_cache).
//
// Grid layout: x = token groups (bdy tokens per block), y = one block row per
// head, with blockIdx.y < num_qo_heads handling a Q head and the remainder a
// K head. Block layout: (bdx lanes per head-dim, bdy tokens). This variant is
// chosen by the launcher when token-only parallelism would underfill the GPU.
//
// cos_sin_cache row layout: [cos(0..rot/2), sin(0..rot/2)] per position.
// The griddepcontrol asm implements the SM90+ programmatic-dependent-launch
// handshake; enabling it is controlled by the launch attribute in the launcher.
template <
    bool save_kv_cache,
    bool interleave,
    uint32_t head_dim,  // note: head geometry is carried via bdx * vec_size
    uint32_t vec_size,
    uint32_t bdx,
    typename DType,
    typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* __restrict__ cos_sin_cache,
    IdType* __restrict__ pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,  // implied by gridDim.y; kept for signature parity
    uint32_t rotary_dim,
    size_t q_stride_n,
    size_t q_stride_h,
    size_t k_stride_n,
    size_t k_stride_h,
    size_t v_stride_n,
    size_t v_stride_h,
    size_t q_rope_stride_n,
    size_t q_rope_stride_h,
    size_t k_rope_stride_n,
    size_t k_rope_stride_h,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h,
    IdType* __restrict__ kv_cache_loc) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  uint32_t by = blockIdx.y;
  const uint32_t bdy = blockDim.y;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];

    const int half_rotary_dim = rotary_dim / 2;

    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

    if (by < num_qo_heads) {
      // This y-block rotates one Q head of this token.
      uint32_t qo_head_idx = by;
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    } else {
      // This y-block rotates one K head (and optionally persists K/V to cache).
      uint32_t kv_head_idx = by - num_qo_heads;
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);

      DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);

      // V is loaded before rotating K so the stores below can go out together.
      vec_t<float, vec_size> v_vec;
      IdType kv_cache_offset;
      if constexpr (save_kv_cache) {
        kv_buffer_saver::prepare<DType, IdType, vec_size>(
            v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
      }

      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);

      if constexpr (save_kv_cache) {
        kv_buffer_saver::save<DType, IdType, vec_size>(
            kv_cache_offset,
            k_vec,
            v_vec,
            k_buffer,
            v_buffer,
            idx,
            tx,
            kv_head_idx,
            k_buffer_stride_n,
            k_buffer_stride_h,
            v_buffer_stride_n,
            v_buffer_stride_h);
      }
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
|
||||
|
||||
// Token-parallel RoPE kernel: same math as the head-parallel variant above,
// but each (block, ty) owns one token and loops over ALL Q heads then all K
// heads serially. Used when there are enough tokens to saturate the GPU with
// a 1D grid (see the launcher's occupancy check).
//
// cos_sin_cache row layout: [cos(0..rot/2), sin(0..rot/2)] per position.
// The griddepcontrol asm implements the SM90+ programmatic-dependent-launch
// handshake; enabling it is controlled by the launch attribute in the launcher.
template <
    bool save_kv_cache,
    bool interleave,
    uint32_t head_dim,  // note: head geometry is carried via bdx * vec_size
    uint32_t vec_size,
    uint32_t bdx,
    typename DType,
    typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* __restrict__ cos_sin_cache,
    IdType* __restrict__ pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,
    uint32_t rotary_dim,
    size_t q_stride_n,
    size_t q_stride_h,
    size_t k_stride_n,
    size_t k_stride_h,
    size_t v_stride_n,
    size_t v_stride_h,
    size_t q_rope_stride_n,
    size_t q_rope_stride_h,
    size_t k_rope_stride_n,
    size_t k_rope_stride_h,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h,
    IdType* __restrict__ kv_cache_loc) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];
    const int half_rotary_dim = rotary_dim / 2;

    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

    // Deliberately not unrolled: the head count may be large, and unrolling
    // would bloat the kernel / hurt performance.
#pragma unroll 1
    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    }

#pragma unroll 1
    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);

      DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);

      // V is loaded before rotating K so the stores below can go out together.
      vec_t<float, vec_size> v_vec;
      IdType kv_cache_offset;
      if constexpr (save_kv_cache) {
        kv_buffer_saver::prepare<DType, IdType, vec_size>(
            v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
      }

      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);

      if constexpr (save_kv_cache) {
        kv_buffer_saver::save<DType, IdType, vec_size>(
            kv_cache_offset,
            k_vec,
            v_vec,
            k_buffer,
            v_buffer,
            idx,
            tx,
            kv_head_idx,
            k_buffer_stride_n,
            k_buffer_stride_h,
            v_buffer_stride_n,
            v_buffer_stride_h);
      }
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
|
||||
|
||||
// Compile-time dispatch on the runtime flag `save_kv_cache`: expands the body
// twice, once with SAVE_KV_CACHE = true and once with false, so the kernels can
// take it as a bool template parameter and use `if constexpr` with no runtime
// cost. Mirrors flashinfer's DISPATCH_INTERLEAVE / DISPATCH_HEAD_DIM pattern.
#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
  if (save_kv_cache) {                                            \
    const bool SAVE_KV_CACHE = true;                              \
    __VA_ARGS__                                                   \
  } else {                                                        \
    const bool SAVE_KV_CACHE = false;                             \
    __VA_ARGS__                                                   \
  }
|
||||
|
||||
// Host-side dispatcher for the two RoPE kernels above. Picks per-thread vector
// width and block shape from HEAD_DIM, then chooses between:
//   - the token-parallel kernel (1D grid) when there are enough token blocks
//     to occupy all SMs, or
//   - the head-parallel kernel (2D grid: token blocks x (Hq + Hkv)) otherwise.
// Launches via cudaLaunchKernelEx so the programmatic-stream-serialization
// (PDL) attribute can be toggled by `enable_pdl`.
// Returns cudaSuccess, or whatever error FLASHINFER_CUDA_CALL propagates --
// presumably it returns early on failure; confirm against flashinfer's utils.
template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* cos_sin_cache,
    IdType* pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,
    uint32_t rotary_dim,
    uint32_t head_dim,
    size_t q_stride_n,
    size_t q_stride_h,
    size_t k_stride_n,
    size_t k_stride_h,
    size_t v_stride_n,
    size_t v_stride_h,
    size_t q_rope_stride_n,
    size_t q_rope_stride_h,
    size_t k_rope_stride_n,
    size_t k_rope_stride_h,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h,
    IdType* kv_cache_loc,
    bool interleave,
    bool save_kv_cache,
    bool enable_pdl,
    cudaStream_t stream = nullptr) {
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

// Launch helper: forwards the full argument list to `kernel_name` via
// cudaLaunchKernelEx, attaching the PDL launch attribute. Relies on `nblks`
// and `nthrs` being in scope at the expansion site.
#define LAUNCH_KERNEL_RAW(kernel_name)                                \
  do {                                                                \
    cudaLaunchConfig_t config = {};                                   \
    config.gridDim = nblks;                                           \
    config.blockDim = nthrs;                                          \
    config.dynamicSmemBytes = 0;                                      \
    config.stream = stream;                                           \
    cudaLaunchAttribute attrs[1] = {};                                \
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \
    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; \
    config.numAttrs = 1;                                              \
    config.attrs = attrs;                                             \
                                                                      \
    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(                          \
        &config,                                                      \
        kernel_name,                                                  \
        q,                                                            \
        k,                                                            \
        v,                                                            \
        q_rope,                                                       \
        k_rope,                                                       \
        k_buffer,                                                     \
        v_buffer,                                                     \
        cos_sin_cache,                                                \
        pos_ids,                                                      \
        nnz,                                                          \
        num_qo_heads,                                                 \
        num_kv_heads,                                                 \
        rotary_dim,                                                   \
        q_stride_n,                                                   \
        q_stride_h,                                                   \
        k_stride_n,                                                   \
        k_stride_h,                                                   \
        v_stride_n,                                                   \
        v_stride_h,                                                   \
        q_rope_stride_n,                                              \
        q_rope_stride_h,                                              \
        k_rope_stride_n,                                              \
        k_rope_stride_h,                                              \
        k_buffer_stride_n,                                            \
        k_buffer_stride_h,                                            \
        v_buffer_stride_n,                                            \
        v_buffer_stride_h,                                            \
        kv_cache_loc));                                               \
  } while (0)

  DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
    DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
      DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
        // operate on 16 Bytes at a time
        constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
        // how many threads needed per head_dim
        constexpr uint32_t bdx = HEAD_DIM / vec_size;
        // how many threads needed per block
        uint32_t num_threads = std::max(128U, bdx);
        // how many tokens can we process in a block
        uint32_t bdy = num_threads / bdx;
        // how many blocks needed to process all tokens
        uint32_t nblks_x = (nnz + bdy - 1) / bdy;

        auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
            SAVE_KV_CACHE,
            INTERLEAVE,
            HEAD_DIM,
            vec_size,
            bdx,
            DType,
            IdType>;

        // Occupancy check: how many resident blocks of kernel_0 fit on the GPU.
        int num_blocks_per_sm_0 = 0;
        FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
        uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;

        if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
          // Enough token blocks to fill the device: token-parallel kernel.
          dim3 nblks(nblks_x);
          dim3 nthrs(bdx, bdy);
          LAUNCH_KERNEL_RAW(kernel_0);
        } else {
          // Too few tokens: spread heads across gridDim.y for extra parallelism.
          dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
          dim3 nthrs(bdx, bdy);
          auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel<
              SAVE_KV_CACHE,
              INTERLEAVE,
              HEAD_DIM,
              vec_size,
              bdx,
              DType,
              IdType>;
          LAUNCH_KERNEL_RAW(kernel_1);
        }
      });
    });
  });
#undef LAUNCH_KERNEL_RAW

  return cudaSuccess;
}
|
||||
|
||||
} // namespace flashinfer
|
||||
|
||||
#endif // SGL_POS_ENC_CUH_
|
||||
164
sgl-kernel/csrc/elementwise/rope.cu
Normal file
164
sgl-kernel/csrc/elementwise/rope.cu
Normal file
@@ -0,0 +1,164 @@
|
||||
/*
|
||||
* Copyright (c) 2024 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pos_enc.cuh"
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
/*
 * Applies rotary position embedding (RoPE) to q/k using a precomputed fused
 * cos/sin cache and explicit position ids, writing the rotated values into
 * q_rope/k_rope. When `v` (and its companion buffers) are supplied, the same
 * launch additionally scatters the rotated k and the raw v into the KV-cache
 * buffers at the rows given by kv_cache_loc (the "Enhanced" path).
 *
 * Shapes (as verified by the checks below):
 *   q:             (nnz, H_Q, D), last dim contiguous
 *   k:             (nnz, H_K, D), last dim contiguous; k.size(0) == q.size(0),
 *                  k.size(2) == q.size(2)
 *   cos_sin_cache: (max_seq_len, R); first half of each row is cos, second
 *                  half is sin; must be contiguous. R is used as rotary_dim.
 *   pos_ids:       positions indexing cos_sin_cache; cast to int64_t* below
 *                  (NOTE(review): dtype is assumed int64 but not checked).
 *   v:             (nnz, H_V, D), last dim contiguous (cache path only)
 *   k_buffer:      (cache_size, H_K, D), last dim contiguous (cache path only)
 *   v_buffer:      (cache_size, H_V, D), last dim contiguous (cache path only)
 *   kv_cache_loc:  (nnz,) destination rows; cast to int64_t* below
 *                  (NOTE(review): dtype assumed int64 but not checked).
 *
 * interleave selects between the two RoPE element layouts supported by the
 * underlying kernels. enable_pdl is only honored on the save-kv-cache path;
 * the legacy path rejects it. cuda_stream is a raw cudaStream_t passed as an
 * integer. Raises via TORCH_CHECK on any validation or CUDA failure.
 */
void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    bool enable_pdl,
    int64_t cuda_stream,
    const std::optional<at::Tensor>& v,
    const std::optional<at::Tensor>& k_buffer,
    const std::optional<at::Tensor>& v_buffer,
    const std::optional<at::Tensor>& kv_cache_loc) {
  CHECK_LAST_DIM_CONTIGUOUS(q);
  CHECK_LAST_DIM_CONTIGUOUS(k);

  // Saving into the KV cache is requested by passing `v`; the companion
  // buffers must then all be present as well.
  const bool save_kv_cache = v.has_value();
  if (save_kv_cache) {
    TORCH_CHECK(k_buffer.has_value());
    TORCH_CHECK(v_buffer.has_value());
    TORCH_CHECK(kv_cache_loc.has_value());
    CHECK_LAST_DIM_CONTIGUOUS(v.value());
    CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
    CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
    CHECK_DIM(3, k_buffer.value());      // k_buffer: (cache_size, H_K, D)
    CHECK_DIM(3, v_buffer.value());      // v_buffer: (cache_size, H_V, D)
    CHECK_DIM(3, v.value());             // v: (nnz, H_V, D)
    CHECK_DIM(1, kv_cache_loc.value());  // kv_cache_loc: (nnz,)
    CHECK_INPUT(kv_cache_loc.value());
  }
  // Cache-path strides/pointers are only meaningful when save_kv_cache is
  // set; use harmless zero/null placeholders otherwise so the non-cache
  // kernel path never dereferences them.
  size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0;
  size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0;
  size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0;
  size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0;
  size_t v_stride_n = save_kv_cache ? v->stride(0) : 0;
  size_t v_stride_h = save_kv_cache ? v->stride(1) : 0;
  auto kv_cache_loc_ptr = save_kv_cache ? static_cast<int64_t*>(kv_cache_loc->data_ptr()) : nullptr;

  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  // All tensors must live on the same device as q.
  auto device = q.device();
  CHECK_EQ(k.device(), device);
  CHECK_EQ(cos_sin_cache.device(), device);
  CHECK_EQ(pos_ids.device(), device);
  CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
  CHECK_DIM(3, k);  // k: (nnz, H_K, D)

  // cos_sin_cache: (max_seq_len, R)
  // First half of R is cos, second half is sin
  CHECK_DIM(2, cos_sin_cache);
  CHECK_EQ(q.size(0), k.size(0));
  CHECK_EQ(q.size(2), k.size(2));
  unsigned int rotary_dim = cos_sin_cache.size(1);
  unsigned int num_qo_heads = q.size(1);
  unsigned int num_kv_heads = k.size(1);
  unsigned int head_dim = q.size(2);
  unsigned int nnz = q.size(0);
  // Strides are taken per-tensor so q/k/q_rope/k_rope may each have distinct
  // (non-contiguous) layouts in their leading two dims.
  size_t q_stride_n = q.stride(0);
  size_t q_stride_h = q.stride(1);
  size_t k_stride_n = k.stride(0);
  size_t k_stride_h = k.stride(1);

  size_t q_rope_stride_n = q_rope.stride(0);
  size_t q_rope_stride_h = q_rope.stride(1);
  size_t k_rope_stride_n = k_rope.stride(0);
  size_t k_rope_stride_h = k_rope.stride(1);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
    // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
    // to avoid changing original code path; but this branch is feature-complete and should switch to this later
    if (save_kv_cache) {
      // Inside this branch save_kv_cache is true, so the optional tensors are
      // guaranteed present and can be dereferenced unconditionally.
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          static_cast<c_type*>(v->data_ptr()),
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          static_cast<c_type*>(k_buffer->data_ptr()),
          static_cast<c_type*>(v_buffer->data_ptr()),
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          v_stride_n,
          v_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          k_buffer_stride_n,
          k_buffer_stride_h,
          v_buffer_stride_n,
          v_buffer_stride_h,
          kv_cache_loc_ptr,
          interleave,
          save_kv_cache,
          enable_pdl,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
              std::string(cudaGetErrorString(status)));
    } else {
      // Legacy path has no PDL support; reject rather than silently ignore.
      TORCH_CHECK(!enable_pdl);
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          interleave,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
    }
    return true;
  });
}
|
||||
Reference in New Issue
Block a user