diff --git a/python/sglang/srt/bench_utils.py b/python/sglang/srt/bench_utils.py new file mode 100644 index 000000000..e9f7fcbb4 --- /dev/null +++ b/python/sglang/srt/bench_utils.py @@ -0,0 +1,137 @@ +import os +import sys +from contextlib import nullcontext + +import torch + + +# NOTE copied and modified from DeepGEMM +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +# NOTE copied and modified from DeepGEMM +def bench_kineto( + fn, + kernel_names, + num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, + flush_l2: bool = True, + with_multiple_kernels: bool = False, +): + # Conflict with Nsight Systems + using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0)) + + # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = ( + suppress_stdout_stderr + if suppress_kineto_output and not using_nsys + else nullcontext + ) + with 
suppress(): + schedule = ( + torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + if not using_nsys + else None + ) + profiler = ( + torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule + ) + if not using_nsys + else nullcontext() + ) + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty( + flush_l2_size, dtype=torch.int, device="cuda" + ).zero_() + fn() + + if not using_nsys: + profiler.step() + + # Return 1 if using Nsight Systems + if using_nsys: + return 1 + + # Parse the profiling table + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + prof_lines = ( + profiler.key_averages() + .table(sort_by="cuda_time_total", max_name_column_width=100) + .split("\n") + ) + kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + if not with_multiple_kernels: + for name in kernel_names: + assert ( + sum([name in line for line in prof_lines]) == 1 + ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})" + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {"ms": 1e3, "us": 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += ( + float(time_str.replace(unit, "")) / scale * int(num_str) + ) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num) + + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/sgl-kernel/benchmark/bench_rotary_embedding.py b/sgl-kernel/benchmark/bench_rotary_embedding.py new file mode 100644 index 000000000..b4e0f5e0b --- /dev/null +++ 
b/sgl-kernel/benchmark/bench_rotary_embedding.py @@ -0,0 +1,96 @@ +import itertools + +import torch +import triton +from sgl_kernel import FusedSetKVBufferArg +from sgl_kernel.testing.rotary_embedding import ( + FlashInferRotaryEmbedding, + MHATokenToKVPool, + RotaryEmbedding, + create_inputs, +) + +from sglang.srt.bench_utils import bench_kineto + +configs = [ + (batch_size, seq_len, save_kv_cache) + for batch_size, seq_len in ( + (1, 1), + (32, 1), + (128, 1), + (512, 1), + (2, 512), + (4, 4096), + ) + for save_kv_cache in (False, True) +] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "save_kv_cache"], + x_vals=configs, + line_arg="provider", + line_vals=["sglang"], + line_names=["SGL Kernel"], + styles=[("green", "-")], + ylabel="us", + plot_name="bench_rotary_embedding", + args={}, + ) +) +def benchmark(batch_size, seq_len, save_kv_cache, provider): + device = torch.device("cuda") + + num_q_heads = 32 + num_kv_heads = 8 + head_size = 64 + dtype = torch.bfloat16 + + config = dict( + head_size=head_size, + rotary_dim=64, + max_position_embeddings=4096, + base=8000, + is_neox_style=True, + dtype=dtype, + ) + rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device) + pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size) + + inputs = create_inputs( + head_size=head_size, + batch_size=batch_size, + seq_len=seq_len, + device=device, + dtype=dtype, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + ) + + query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone() + + bench_fn = lambda: rope_flashinfer.forward_cuda( + inputs["pos_ids"], + query_flashinfer, + key_flashinfer, + fused_set_kv_buffer_arg=( + FusedSetKVBufferArg( + value=inputs["value"], + k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size), + v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size), + k_scale=None, + v_scale=None, + 
cache_loc=inputs["out_cache_loc"], + ) + if save_kv_cache + else None + ), + ) + + time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds") + return time_s * 1e6 + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index fde329699..86ef29f24 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -89,7 +89,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def( "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " - "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); + "Tensor pos_ids, bool interleave, int cuda_stream, " + "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()"); m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); /* diff --git a/sgl-kernel/csrc/elementwise/pos_enc.cuh b/sgl-kernel/csrc/elementwise/pos_enc.cuh new file mode 100644 index 000000000..5388f0e74 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/pos_enc.cuh @@ -0,0 +1,431 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef SGL_POS_ENC_CUH_ +#define SGL_POS_ENC_CUH_ + +#include // upstream + +namespace flashinfer { + +namespace kv_buffer_saver { + +template +__device__ __forceinline__ void prepare( + vec_t& v_vec, + IdType& kv_cache_offset, + DType* v, + IdType* kv_cache_loc, + uint32_t idx, + uint32_t tx, + uint32_t kv_head_idx, + size_t v_stride_n, + size_t v_stride_h) { + kv_cache_offset = kv_cache_loc[idx]; + + DType* v_ptr = v + get_elem_offset_impl(idx, kv_head_idx, 0, v_stride_n, v_stride_h); + v_vec.cast_load(v_ptr + tx * vec_size); +} + +template +__device__ __forceinline__ void save( + IdType& kv_cache_offset, + vec_t& k_vec, + vec_t& v_vec, + DType* k_buffer, + DType* v_buffer, + uint32_t idx, + uint32_t tx, + uint32_t kv_head_idx, + size_t k_buffer_stride_n, + size_t k_buffer_stride_h, + size_t v_buffer_stride_n, + size_t v_buffer_stride_h) { + DType* k_buffer_ptr = + k_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, k_buffer_stride_n, k_buffer_stride_h); + DType* v_buffer_ptr = + v_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, v_buffer_stride_n, v_buffer_stride_h); + k_vec.cast_store(k_buffer_ptr + tx * vec_size); + v_vec.cast_store(v_buffer_ptr + tx * vec_size); +} + +} // namespace kv_buffer_saver + +template < + bool save_kv_cache, + bool interleave, + uint32_t head_dim, + uint32_t vec_size, + uint32_t bdx, + typename DType, + typename IdType> +__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel( + DType* q, + DType* k, + DType* v, + DType* q_rope, + DType* k_rope, + DType* k_buffer, + DType* v_buffer, + float* __restrict__ cos_sin_cache, + IdType* __restrict__ pos_ids, + uint32_t nnz, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t rotary_dim, + size_t q_stride_n, + size_t q_stride_h, + size_t k_stride_n, + size_t k_stride_h, + size_t v_stride_n, + size_t v_stride_h, + size_t q_rope_stride_n, + size_t q_rope_stride_h, + size_t k_rope_stride_n, + size_t k_rope_stride_h, + size_t 
k_buffer_stride_n, + size_t k_buffer_stride_h, + size_t v_buffer_stride_n, + size_t v_buffer_stride_h, + IdType* __restrict__ kv_cache_loc) { + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + uint32_t by = blockIdx.y; + const uint32_t bdy = blockDim.y; + + vec_t cos, sin; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + + const int half_rotary_dim = rotary_dim / 2; + + // 1. if interleave: + // - cos = cos_sin_cache[pos_id][tx * vec_size // 2] + // - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2] + // 2. if not interleave + // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] + // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)] + if (tx * vec_size < rotary_dim) { + int sin_offset = rotary_dim / 2; + int vec_idx; + if constexpr (interleave) { + vec_idx = (tx * vec_size) / 2; // Force integer division + } else { + vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim + } + cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx); + sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx)); + } + + if (by < num_qo_heads) { + uint32_t qo_head_idx = by; + DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); + vec_t q_vec; + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(q_ptr, cos, sin, rotary_dim); + } else { + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); + } + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } else { + uint32_t kv_head_idx = by - num_qo_heads; + DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); + + DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); + + vec_t v_vec; + IdType kv_cache_offset; + if constexpr 
(save_kv_cache) { + kv_buffer_saver::prepare( + v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h); + } + + vec_t k_vec; + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(k_ptr, cos, sin, rotary_dim); + } else { + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); + } + k_vec.cast_store(k_rope_ptr + tx * vec_size); + + if constexpr (save_kv_cache) { + kv_buffer_saver::save( + kv_cache_offset, + k_vec, + v_vec, + k_buffer, + v_buffer, + idx, + tx, + kv_head_idx, + k_buffer_stride_n, + k_buffer_stride_h, + v_buffer_stride_n, + v_buffer_stride_h); + } + } + } +} + +template < + bool save_kv_cache, + bool interleave, + uint32_t head_dim, + uint32_t vec_size, + uint32_t bdx, + typename DType, + typename IdType> +__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel( + DType* q, + DType* k, + DType* v, + DType* q_rope, + DType* k_rope, + DType* k_buffer, + DType* v_buffer, + float* __restrict__ cos_sin_cache, + IdType* __restrict__ pos_ids, + uint32_t nnz, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t rotary_dim, + size_t q_stride_n, + size_t q_stride_h, + size_t k_stride_n, + size_t k_stride_h, + size_t v_stride_n, + size_t v_stride_h, + size_t q_rope_stride_n, + size_t q_rope_stride_h, + size_t k_rope_stride_n, + size_t k_rope_stride_h, + size_t k_buffer_stride_n, + size_t k_buffer_stride_h, + size_t v_buffer_stride_n, + size_t v_buffer_stride_h, + IdType* __restrict__ kv_cache_loc) { + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + const uint32_t bdy = blockDim.y; + + vec_t cos, sin; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + const int half_rotary_dim = rotary_dim / 2; + + // 1. if interleave: + // - cos = cos_sin_cache[pos_id][tx * vec_size // 2] + // - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2] + // 2. 
if not interleave + // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] + // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)] + if (tx * vec_size < rotary_dim) { + int sin_offset = rotary_dim / 2; + int vec_idx; + if constexpr (interleave) { + vec_idx = (tx * vec_size) / 2; // Force integer division + } else { + vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim + } + cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx); + sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx)); + } + + // not to unroll the loop, because num head might be large and might lead to worse performance +#pragma unroll 1 + for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); + vec_t q_vec; + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(q_ptr, cos, sin, rotary_dim); + } else { + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); + } + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } + +#pragma unroll 1 + for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) { + DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); + + DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); + + vec_t v_vec; + IdType kv_cache_offset; + if constexpr (save_kv_cache) { + kv_buffer_saver::prepare( + v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h); + } + + vec_t k_vec; + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(k_ptr, cos, sin, rotary_dim); + } else { + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); + } + k_vec.cast_store(k_rope_ptr + tx * vec_size); + + if 
constexpr (save_kv_cache) { + kv_buffer_saver::save( + kv_cache_offset, + k_vec, + v_vec, + k_buffer, + v_buffer, + idx, + tx, + kv_head_idx, + k_buffer_stride_n, + k_buffer_stride_h, + v_buffer_stride_n, + v_buffer_stride_h); + } + } + } +} + +#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \ + if (save_kv_cache) { \ + const bool SAVE_KV_CACHE = true; \ + __VA_ARGS__ \ + } else { \ + const bool SAVE_KV_CACHE = false; \ + __VA_ARGS__ \ + } + +template +cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced( + DType* q, + DType* k, + DType* v, + DType* q_rope, + DType* k_rope, + DType* k_buffer, + DType* v_buffer, + float* cos_sin_cache, + IdType* pos_ids, + uint32_t nnz, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t rotary_dim, + uint32_t head_dim, + size_t q_stride_n, + size_t q_stride_h, + size_t k_stride_n, + size_t k_stride_h, + size_t v_stride_n, + size_t v_stride_h, + size_t q_rope_stride_n, + size_t q_rope_stride_h, + size_t k_rope_stride_n, + size_t k_rope_stride_h, + size_t k_buffer_stride_n, + size_t k_buffer_stride_h, + size_t v_buffer_stride_n, + size_t v_buffer_stride_h, + IdType* kv_cache_loc, + bool interleave, + bool save_kv_cache, + cudaStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + + DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, { + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + // operate on 16 Bytes at a time + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + // how many threads needed per head_dim + constexpr uint32_t bdx = HEAD_DIM / vec_size; + // how many threads needed per block + uint32_t num_threads = std::max(128U, bdx); + // how many tokens can we process in a block + uint32_t bdy = num_threads / bdx; + // how many blocks needed to process all tokens + uint32_t 
nblks_x = (nnz + bdy - 1) / bdy; + void* args[] = { + (void*)&q, + (void*)&k, + (void*)&v, + (void*)&q_rope, + (void*)&k_rope, + (void*)&k_buffer, + (void*)&v_buffer, + (void*)&cos_sin_cache, + (void*)&pos_ids, + (void*)&nnz, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&rotary_dim, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&v_stride_n, + (void*)&v_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, + (void*)&k_buffer_stride_n, + (void*)&k_buffer_stride_h, + (void*)&v_buffer_stride_n, + (void*)&v_buffer_stride_h, + (void*)&kv_cache_loc}; + auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel< + SAVE_KV_CACHE, + INTERLEAVE, + HEAD_DIM, + vec_size, + bdx, + DType, + IdType>; + + int num_blocks_per_sm_0 = 0; + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0)); + uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms; + + if ((nnz + bdy - 1) / bdy >= num_ctas_0) { + dim3 nblks(nblks_x); + dim3 nthrs(bdx, bdy); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream)); + } else { + dim3 nblks(nblks_x, num_qo_heads + num_kv_heads); + dim3 nthrs(bdx, bdy); + auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel< + SAVE_KV_CACHE, + INTERLEAVE, + HEAD_DIM, + vec_size, + bdx, + DType, + IdType>; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream)); + } + }); + }); + }); + + return cudaSuccess; +} + +} // namespace flashinfer + +#endif // SGL_POS_ENC_CUH_ diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index 4274acf43..41cad7dd4 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under 
the License. */ -#include +#include "pos_enc.cuh" #include "pytorch_extension_utils.h" using namespace flashinfer; @@ -27,9 +27,37 @@ void apply_rope_pos_ids_cos_sin_cache( at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, - int64_t cuda_stream) { + int64_t cuda_stream, + const std::optional& v, + const std::optional& k_buffer, + const std::optional& v_buffer, + const std::optional& kv_cache_loc) { CHECK_LAST_DIM_CONTIGUOUS(q); CHECK_LAST_DIM_CONTIGUOUS(k); + + const bool save_kv_cache = v.has_value(); + if (save_kv_cache) { + TORCH_CHECK(v.has_value()); + TORCH_CHECK(k_buffer.has_value()); + TORCH_CHECK(v_buffer.has_value()); + TORCH_CHECK(kv_cache_loc.has_value()); + CHECK_LAST_DIM_CONTIGUOUS(v.value()); + CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value()); + CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value()); + CHECK_DIM(3, k_buffer.value()); // k_buffer: (nnz, H_K, D) + CHECK_DIM(3, v_buffer.value()); // v_buffer: (nnz, H_V, D) + CHECK_DIM(3, v.value()); // v: (nnz, H_V, D) + CHECK_DIM(1, kv_cache_loc.value()); // v: (n) + CHECK_INPUT(kv_cache_loc.value()); + } + size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0; + size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0; + size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0; + size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0; + size_t v_stride_n = save_kv_cache ? v->stride(0) : 0; + size_t v_stride_h = save_kv_cache ? v->stride(1) : 0; + auto kv_cache_loc_ptr = save_kv_cache ? 
static_cast(kv_cache_loc->data_ptr()) : nullptr; + CHECK_INPUT(cos_sin_cache); CHECK_INPUT(pos_ids); auto device = q.device(); @@ -38,6 +66,7 @@ void apply_rope_pos_ids_cos_sin_cache( CHECK_EQ(pos_ids.device(), device); CHECK_DIM(3, q); // q: (nnz, H_Q, D) CHECK_DIM(3, k); // k: (nnz, H_K, D) + // cos_sin_cache: (max_seq_len, R) // First half of R is cos, second half is sin CHECK_DIM(2, cos_sin_cache); @@ -52,6 +81,7 @@ void apply_rope_pos_ids_cos_sin_cache( size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); size_t q_rope_stride_h = q_rope.stride(1); size_t k_rope_stride_n = k_rope.stride(0); @@ -59,31 +89,73 @@ void apply_rope_pos_ids_cos_sin_cache( cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache( - static_cast(q.data_ptr()), - static_cast(k.data_ptr()), - static_cast(q_rope.data_ptr()), - static_cast(k_rope.data_ptr()), - static_cast(cos_sin_cache.data_ptr()), - static_cast(pos_ids.data_ptr()), - nnz, - num_qo_heads, - num_kv_heads, - rotary_dim, - head_dim, - q_stride_n, - q_stride_h, - k_stride_n, - k_stride_h, - q_rope_stride_n, - q_rope_stride_h, - k_rope_stride_n, - k_rope_stride_h, - interleave, - stream); - TORCH_CHECK( - status == cudaSuccess, - "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status))); + // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache + // to avoid changing original code path; but this branch is feature-complete and should switch to this later + if (save_kv_cache) { + cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced( + static_cast(q.data_ptr()), + static_cast(k.data_ptr()), + save_kv_cache ? 
static_cast(v->data_ptr()) : nullptr, + static_cast(q_rope.data_ptr()), + static_cast(k_rope.data_ptr()), + save_kv_cache ? static_cast(k_buffer->data_ptr()) : nullptr, + save_kv_cache ? static_cast(v_buffer->data_ptr()) : nullptr, + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), + nnz, + num_qo_heads, + num_kv_heads, + rotary_dim, + head_dim, + q_stride_n, + q_stride_h, + k_stride_n, + k_stride_h, + v_stride_n, + v_stride_h, + q_rope_stride_n, + q_rope_stride_h, + k_rope_stride_n, + k_rope_stride_h, + k_buffer_stride_n, + k_buffer_stride_h, + v_buffer_stride_n, + v_buffer_stride_h, + kv_cache_loc_ptr, + interleave, + save_kv_cache, + stream); + TORCH_CHECK( + status == cudaSuccess, + "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " + + std::string(cudaGetErrorString(status))); + } else { + cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache( + static_cast(q.data_ptr()), + static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), + static_cast(k_rope.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), + nnz, + num_qo_heads, + num_kv_heads, + rotary_dim, + head_dim, + q_stride_n, + q_stride_h, + k_stride_n, + k_stride_h, + q_rope_stride_n, + q_rope_stride_h, + k_rope_stride_n, + k_rope_stride_h, + interleave, + stream); + TORCH_CHECK( + status == cudaSuccess, + "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status))); + } return true; }); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 9e71cd8c8..c007251cd 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -150,7 +150,11 @@ void apply_rope_pos_ids_cos_sin_cache( at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, - int64_t cuda_stream); + int64_t cuda_stream, + const std::optional& v, + const std::optional& k_buffer, + const std::optional& v_buffer, + const std::optional& 
kv_cache_loc); #ifdef USE_ROCM void gelu_quick(at::Tensor& out, const at::Tensor& input); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 2a4656aea..faeff9240 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -21,6 +21,7 @@ from sgl_kernel.attention import ( ) from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data from sgl_kernel.elementwise import ( + FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace, fused_add_rmsnorm, gelu_and_mul, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 01ee71860..aa62d65d4 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -1,4 +1,5 @@ -from typing import Optional +from dataclasses import dataclass +from typing import Any, Optional import torch from sgl_kernel.utils import get_cuda_stream, is_hopper_arch @@ -237,6 +238,31 @@ if torch.version.hip is not None: return out +@dataclass +class FusedSetKVBufferArg: + """ + value : Optional[torch.Tensor] + Value tensor, shape: ``(nnz, num_v_heads * head_size)``. + k_buffer : Optional[torch.Tensor] + Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``. + v_buffer : Optional[torch.Tensor] + Buffer for values, shape: ``(nnz, num_v_heads * head_size)``. + k_scale : Optional[float] + Scale factor for keys. + v_scale : Optional[float] + Scale factor for values. + cache_loc : Optional[torch.Tensor] + Cache location tensor, used for indexing kv cache. 
+ """ + + value: torch.Tensor + k_buffer: torch.Tensor + v_buffer: torch.Tensor + k_scale: Optional[float] + v_scale: Optional[float] + cache_loc: torch.Tensor + + def apply_rope_with_cos_sin_cache_inplace( positions: torch.Tensor, query: torch.Tensor, @@ -244,6 +270,7 @@ def apply_rope_with_cos_sin_cache_inplace( head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool = True, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> None: r""" Apply rotary embedding to keys and queries with precomputed cos/sin values. @@ -270,6 +297,9 @@ def apply_rope_with_cos_sin_cache_inplace( * If ``False``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + fused_set_kv_buffer_arg : FusedSetKVBufferArg + Fuse the set-kv-buffer operation into this kernel + Note ---- The rotary dimension is determined by the cosine cache and sine cache. @@ -277,13 +307,41 @@ def apply_rope_with_cos_sin_cache_inplace( if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") + if (a := fused_set_kv_buffer_arg) is not None: + assert a.k_scale is None, "k_scale is not yet supported" + assert a.v_scale is None, "v_scale is not yet supported" + assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}" + + def _view_3d(x): + return x.view(x.shape[0], -1, head_size) + torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default( - query.view(query.shape[0], -1, head_size), - key.view(key.shape[0], -1, head_size), - query.view(query.shape[0], -1, head_size), - key.view(key.shape[0], -1, head_size), + _view_3d(query), + _view_3d(key), + _view_3d(query), + _view_3d(key), cos_sin_cache, positions.long(), (not is_neox), get_cuda_stream(), + ( + _view_3d(fused_set_kv_buffer_arg.value) + if fused_set_kv_buffer_arg is not None + else None + ), + ( + _view_3d(fused_set_kv_buffer_arg.k_buffer) + if fused_set_kv_buffer_arg is not None + else None + ), + ( 
+ _view_3d(fused_set_kv_buffer_arg.v_buffer) + if fused_set_kv_buffer_arg is not None + else None + ), + ( + fused_set_kv_buffer_arg.cache_loc + if fused_set_kv_buffer_arg is not None + else None + ), ) diff --git a/sgl-kernel/python/sgl_kernel/testing/__init__.py b/sgl-kernel/python/sgl_kernel/testing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py b/sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py new file mode 100644 index 000000000..e26208048 --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py @@ -0,0 +1,217 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytest +import torch +from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace + + +# vLLM torch native +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(torch.nn.Module): + # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = 
self.cos_sin_cache.index_select(0, positions) + + # Modification: float32 is required for the rotary embedding to work correctly + query = query.to(torch.float32) + key = key.to(torch.float32) + + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + + # Modification: convert to the correct dtype + query = query.to(self.dtype) + key = key.to(self.dtype) + return query, key + + +class FlashInferRotaryEmbedding(RotaryEmbedding): + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + fused_set_kv_buffer_arg=fused_set_kv_buffer_arg, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox=self.is_neox_style, + ) + + return query, key + + +class MHATokenToKVPool: + KV_POOL_SIZE = 16384 + + def __init__( + self, + head_num: int, + head_dim: int, + ): + self.head_num = head_num + self.head_dim = head_dim + self.size = MHATokenToKVPool.KV_POOL_SIZE + self.page_size = 1 + self.store_dtype = torch.bfloat16 + self.device = "cuda" + self.layer_num = 1 + self.start_layer = 0 + self._create_buffers() + + def _create_buffers(self): + self.k_buffer = [ + torch.zeros( + (self.size + self.page_size, 
self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + self.v_buffer = [ + torch.zeros( + (self.size + self.page_size, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + + def set_kv_buffer( + self, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ): + layer_id = 0 + self.k_buffer[layer_id - self.start_layer][loc] = cache_k + self.v_buffer[layer_id - self.start_layer][loc] = cache_v + + +def create_inputs( + head_size: int, + batch_size: int, + seq_len: int, + device, + dtype: torch.dtype, + num_q_heads: int, + num_kv_heads: int, +): + pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) + query = torch.randn( + batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + ) + value = torch.randn( + batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + ) + out_cache_loc = torch.randperm( + MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device + )[: batch_size * seq_len].clone() + + return dict( + pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc + ) diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py index 539b51d84..d9f9364b0 100644 --- a/sgl-kernel/tests/test_rotary_embedding.py +++ b/sgl-kernel/tests/test_rotary_embedding.py @@ -2,153 +2,51 @@ from typing import Any, Dict, List, Optional, Tuple, Union import pytest import torch -from sgl_kernel import apply_rope_with_cos_sin_cache_inplace - - -# vLLM torch native -def _apply_rotary_emb( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: [num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - 
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. - """ - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -class RotaryEmbedding(torch.nn.Module): - # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.is_neox_style = is_neox_style - self.dtype = dtype - - cache = self._compute_cos_sin_cache() - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: - inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim - ) - ) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward_native( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets - - 
positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - - # Modification: float32 is required for the rotary embedding to work correctly - query = query.to(torch.float32) - key = key.to(torch.float32) - - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - - # Modification: convert to the correct dtype - query = query.to(self.dtype) - key = key.to(self.dtype) - return query, key - - -class FlashInferRotaryEmbedding(RotaryEmbedding): - def forward_cuda( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - - apply_rope_with_cos_sin_cache_inplace( - positions=positions, - query=query, - key=key, - head_size=self.head_size, - cos_sin_cache=self.cos_sin_cache, - is_neox=self.is_neox_style, - ) - - return query, key +from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace +from sgl_kernel.testing.rotary_embedding import ( + FlashInferRotaryEmbedding, + MHATokenToKVPool, + RotaryEmbedding, + create_inputs, +) @pytest.mark.parametrize( - "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, 
num_q_heads, num_kv_heads, save_kv_cache", [ - (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1), - (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2), - (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2), - (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8), - (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4), - (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2), + # GPT-OSS cases + *[ + ( + 64, + 64, + 4096, + 8000, + True, + torch.bfloat16, + "cuda", + batch_size, + seq_len, + 64, + 8, + save_kv_cache, + ) + for batch_size, seq_len in ( + (1, 1), + (32, 1), + (128, 1), + (512, 1), + (2, 512), + (4, 4096), + ) + for save_kv_cache in (False, True) + ], + # Other cases + (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False), + (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False), + (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False), + (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False), ], ) def test_correctness( @@ -163,34 +61,77 @@ def test_correctness( seq_len: int, num_q_heads: int, num_kv_heads: int, + save_kv_cache: bool, ): - rope_ref = RotaryEmbedding( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype - ).to(device) - rope_flashinfer = FlashInferRotaryEmbedding( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype - ).to(device) - - pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) - query = torch.randn( - batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device - ) - key = torch.randn( - batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + config = dict( + head_size=head_size, + rotary_dim=rotary_dim, + 
max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, ) - query_ref, key_ref = query.clone(), key.clone() - query_flashinfer, key_flashinfer = query.clone(), key.clone() + rope_ref = RotaryEmbedding(**config).to(device) + rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device) + + inputs = create_inputs( + head_size=head_size, + batch_size=batch_size, + seq_len=seq_len, + device=device, + dtype=dtype, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + ) + + if save_kv_cache: + pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size) + pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size) + + query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone() + query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone() + + query_ref_out, key_ref_out = rope_ref.forward_native( + inputs["pos_ids"], query_ref, key_ref + ) + if save_kv_cache: + pool_ref.set_kv_buffer( + loc=inputs["out_cache_loc"], + cache_k=key_ref_out.view(-1, num_kv_heads, head_size), + cache_v=inputs["value"].view(-1, num_kv_heads, head_size), + ) - query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref) query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda( - pos_ids, query_flashinfer, key_flashinfer + inputs["pos_ids"], + query_flashinfer, + key_flashinfer, + fused_set_kv_buffer_arg=( + FusedSetKVBufferArg( + value=inputs["value"], + k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size), + v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size), + k_scale=None, + v_scale=None, + cache_loc=inputs["out_cache_loc"], + ) + if save_kv_cache + else None + ), ) torch.testing.assert_close( query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2 ) torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2) + if save_kv_cache: + for field in ["k_buffer", "v_buffer"]: + x_ref = 
getattr(pool_ref, field)[0]
+            x_flashinfer = getattr(pool_flashinfer, field)[0]
+            torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
+            nonzero_ref = x_ref != 0
+            nonzero_flashinfer = x_flashinfer != 0
+            assert torch.all(nonzero_ref == nonzero_flashinfer)
 
 
 if __name__ == "__main__":