Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) (#9077)

This commit is contained in:
fzyzcjy
2025-08-12 16:46:40 +08:00
committed by GitHub
parent fcc11e5ed5
commit 9aea255522
11 changed files with 1152 additions and 194 deletions

View File

@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/pos_enc.cuh>
#include "pos_enc.cuh"
#include "pytorch_extension_utils.h"
using namespace flashinfer;
@@ -27,9 +27,37 @@ void apply_rope_pos_ids_cos_sin_cache(
at::Tensor cos_sin_cache,
at::Tensor pos_ids,
bool interleave,
int64_t cuda_stream) {
int64_t cuda_stream,
const std::optional<at::Tensor>& v,
const std::optional<at::Tensor>& k_buffer,
const std::optional<at::Tensor>& v_buffer,
const std::optional<at::Tensor>& kv_cache_loc) {
CHECK_LAST_DIM_CONTIGUOUS(q);
CHECK_LAST_DIM_CONTIGUOUS(k);
const bool save_kv_cache = v.has_value();
if (save_kv_cache) {
TORCH_CHECK(v.has_value());
TORCH_CHECK(k_buffer.has_value());
TORCH_CHECK(v_buffer.has_value());
TORCH_CHECK(kv_cache_loc.has_value());
CHECK_LAST_DIM_CONTIGUOUS(v.value());
CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
CHECK_DIM(3, k_buffer.value()); // k_buffer: (nnz, H_K, D)
CHECK_DIM(3, v_buffer.value()); // v_buffer: (nnz, H_V, D)
CHECK_DIM(3, v.value()); // v: (nnz, H_V, D)
CHECK_DIM(1, kv_cache_loc.value()); // v: (n)
CHECK_INPUT(kv_cache_loc.value());
}
size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0;
size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0;
size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0;
size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0;
size_t v_stride_n = save_kv_cache ? v->stride(0) : 0;
size_t v_stride_h = save_kv_cache ? v->stride(1) : 0;
auto kv_cache_loc_ptr = save_kv_cache ? static_cast<int64_t*>(kv_cache_loc->data_ptr()) : nullptr;
CHECK_INPUT(cos_sin_cache);
CHECK_INPUT(pos_ids);
auto device = q.device();
@@ -38,6 +66,7 @@ void apply_rope_pos_ids_cos_sin_cache(
CHECK_EQ(pos_ids.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
// cos_sin_cache: (max_seq_len, R)
// First half of R is cos, second half is sin
CHECK_DIM(2, cos_sin_cache);
@@ -52,6 +81,7 @@ void apply_rope_pos_ids_cos_sin_cache(
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
size_t q_rope_stride_n = q_rope.stride(0);
size_t q_rope_stride_h = q_rope.stride(1);
size_t k_rope_stride_n = k_rope.stride(0);
@@ -59,31 +89,73 @@ void apply_rope_pos_ids_cos_sin_cache(
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
static_cast<c_type*>(q.data_ptr()),
static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(q_rope.data_ptr()),
static_cast<c_type*>(k_rope.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int64_t*>(pos_ids.data_ptr()),
nnz,
num_qo_heads,
num_kv_heads,
rotary_dim,
head_dim,
q_stride_n,
q_stride_h,
k_stride_n,
k_stride_h,
q_rope_stride_n,
q_rope_stride_h,
k_rope_stride_n,
k_rope_stride_h,
interleave,
stream);
TORCH_CHECK(
status == cudaSuccess,
"BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
// TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
// to avoid changing original code path; but this branch is feature-complete and should switch to this later
if (save_kv_cache) {
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
static_cast<c_type*>(q.data_ptr()),
static_cast<c_type*>(k.data_ptr()),
save_kv_cache ? static_cast<c_type*>(v->data_ptr()) : nullptr,
static_cast<c_type*>(q_rope.data_ptr()),
static_cast<c_type*>(k_rope.data_ptr()),
save_kv_cache ? static_cast<c_type*>(k_buffer->data_ptr()) : nullptr,
save_kv_cache ? static_cast<c_type*>(v_buffer->data_ptr()) : nullptr,
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int64_t*>(pos_ids.data_ptr()),
nnz,
num_qo_heads,
num_kv_heads,
rotary_dim,
head_dim,
q_stride_n,
q_stride_h,
k_stride_n,
k_stride_h,
v_stride_n,
v_stride_h,
q_rope_stride_n,
q_rope_stride_h,
k_rope_stride_n,
k_rope_stride_h,
k_buffer_stride_n,
k_buffer_stride_h,
v_buffer_stride_n,
v_buffer_stride_h,
kv_cache_loc_ptr,
interleave,
save_kv_cache,
stream);
TORCH_CHECK(
status == cudaSuccess,
"BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
std::string(cudaGetErrorString(status)));
} else {
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
static_cast<c_type*>(q.data_ptr()),
static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(q_rope.data_ptr()),
static_cast<c_type*>(k_rope.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int64_t*>(pos_ids.data_ptr()),
nnz,
num_qo_heads,
num_kv_heads,
rotary_dim,
head_dim,
q_stride_n,
q_stride_h,
k_stride_n,
k_stride_h,
q_rope_stride_n,
q_rope_stride_h,
k_rope_stride_n,
k_rope_stride_h,
interleave,
stream);
TORCH_CHECK(
status == cudaSuccess,
"BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
}
return true;
});
}