Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) (#9077)
This commit is contained in:
137
python/sglang/srt/bench_utils.py
Normal file
137
python/sglang/srt/bench_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# NOTE copied and modified from DeepGEMM
|
||||
class suppress_stdout_stderr:
|
||||
def __enter__(self):
|
||||
self.outnull_file = open(os.devnull, "w")
|
||||
self.errnull_file = open(os.devnull, "w")
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
# NOTE copied and modified from DeepGEMM
|
||||
def bench_kineto(
|
||||
fn,
|
||||
kernel_names,
|
||||
num_tests: int = 30,
|
||||
suppress_kineto_output: bool = False,
|
||||
trace_path: str = None,
|
||||
flush_l2: bool = True,
|
||||
with_multiple_kernels: bool = False,
|
||||
):
|
||||
# Conflict with Nsight Systems
|
||||
using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))
|
||||
|
||||
# By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
|
||||
flush_l2_size = int(8e9 // 4)
|
||||
|
||||
# For some auto-tuning kernels with prints
|
||||
fn()
|
||||
|
||||
# Profile
|
||||
suppress = (
|
||||
suppress_stdout_stderr
|
||||
if suppress_kineto_output and not using_nsys
|
||||
else nullcontext
|
||||
)
|
||||
with suppress():
|
||||
schedule = (
|
||||
torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
|
||||
if not using_nsys
|
||||
else None
|
||||
)
|
||||
profiler = (
|
||||
torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
|
||||
)
|
||||
if not using_nsys
|
||||
else nullcontext()
|
||||
)
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
for _ in range(num_tests):
|
||||
if flush_l2:
|
||||
torch.empty(
|
||||
flush_l2_size, dtype=torch.int, device="cuda"
|
||||
).zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
profiler.step()
|
||||
|
||||
# Return 1 if using Nsight Systems
|
||||
if using_nsys:
|
||||
return 1
|
||||
|
||||
# Parse the profiling table
|
||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
||||
is_tuple = isinstance(kernel_names, tuple)
|
||||
prof_lines = (
|
||||
profiler.key_averages()
|
||||
.table(sort_by="cuda_time_total", max_name_column_width=100)
|
||||
.split("\n")
|
||||
)
|
||||
kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert (
|
||||
sum([name in line for line in prof_lines]) == 1
|
||||
), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
profiler.export_chrome_trace(trace_path)
|
||||
|
||||
# Return average kernel times
|
||||
units = {"ms": 1e3, "us": 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
total_time += (
|
||||
float(time_str.replace(unit, "")) / scale * int(num_str)
|
||||
)
|
||||
total_num += int(num_str)
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
|
||||
return tuple(kernel_times) if is_tuple else kernel_times[0]
|
||||
96
sgl-kernel/benchmark/bench_rotary_embedding.py
Normal file
96
sgl-kernel/benchmark/bench_rotary_embedding.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import FusedSetKVBufferArg
|
||||
from sgl_kernel.testing.rotary_embedding import (
|
||||
FlashInferRotaryEmbedding,
|
||||
MHATokenToKVPool,
|
||||
RotaryEmbedding,
|
||||
create_inputs,
|
||||
)
|
||||
|
||||
from sglang.srt.bench_utils import bench_kineto
|
||||
|
||||
configs = [
|
||||
(batch_size, seq_len, save_kv_cache)
|
||||
for batch_size, seq_len in (
|
||||
(1, 1),
|
||||
(32, 1),
|
||||
(128, 1),
|
||||
(512, 1),
|
||||
(2, 512),
|
||||
(4, 4096),
|
||||
)
|
||||
for save_kv_cache in (False, True)
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len", "save_kv_cache"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["sglang"],
|
||||
line_names=["SGL Kernel"],
|
||||
styles=[("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="bench_rotary_embedding",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, seq_len, save_kv_cache, provider):
|
||||
device = torch.device("cuda")
|
||||
|
||||
num_q_heads = 32
|
||||
num_kv_heads = 8
|
||||
head_size = 64
|
||||
dtype = torch.bfloat16
|
||||
|
||||
config = dict(
|
||||
head_size=head_size,
|
||||
rotary_dim=64,
|
||||
max_position_embeddings=4096,
|
||||
base=8000,
|
||||
is_neox_style=True,
|
||||
dtype=dtype,
|
||||
)
|
||||
rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
|
||||
pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
|
||||
|
||||
inputs = create_inputs(
|
||||
head_size=head_size,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
)
|
||||
|
||||
query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()
|
||||
|
||||
bench_fn = lambda: rope_flashinfer.forward_cuda(
|
||||
inputs["pos_ids"],
|
||||
query_flashinfer,
|
||||
key_flashinfer,
|
||||
fused_set_kv_buffer_arg=(
|
||||
FusedSetKVBufferArg(
|
||||
value=inputs["value"],
|
||||
k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
|
||||
v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
|
||||
k_scale=None,
|
||||
v_scale=None,
|
||||
cache_loc=inputs["out_cache_loc"],
|
||||
)
|
||||
if save_kv_cache
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
|
||||
return time_s * 1e6
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
@@ -89,7 +89,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
m.def(
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream, "
|
||||
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
/*
|
||||
|
||||
431
sgl-kernel/csrc/elementwise/pos_enc.cuh
Normal file
431
sgl-kernel/csrc/elementwise/pos_enc.cuh
Normal file
@@ -0,0 +1,431 @@
|
||||
/*
|
||||
* Copyright (c) 2023 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef SGL_POS_ENC_CUH_
|
||||
#define SGL_POS_ENC_CUH_
|
||||
|
||||
#include <flashinfer/pos_enc.cuh> // upstream
|
||||
|
||||
namespace flashinfer {
|
||||
|
||||
namespace kv_buffer_saver {
|
||||
|
||||
template <typename DType, typename IdType, uint32_t vec_size>
|
||||
__device__ __forceinline__ void prepare(
|
||||
vec_t<float, vec_size>& v_vec,
|
||||
IdType& kv_cache_offset,
|
||||
DType* v,
|
||||
IdType* kv_cache_loc,
|
||||
uint32_t idx,
|
||||
uint32_t tx,
|
||||
uint32_t kv_head_idx,
|
||||
size_t v_stride_n,
|
||||
size_t v_stride_h) {
|
||||
kv_cache_offset = kv_cache_loc[idx];
|
||||
|
||||
DType* v_ptr = v + get_elem_offset_impl(idx, kv_head_idx, 0, v_stride_n, v_stride_h);
|
||||
v_vec.cast_load(v_ptr + tx * vec_size);
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType, uint32_t vec_size>
|
||||
__device__ __forceinline__ void save(
|
||||
IdType& kv_cache_offset,
|
||||
vec_t<float, vec_size>& k_vec,
|
||||
vec_t<float, vec_size>& v_vec,
|
||||
DType* k_buffer,
|
||||
DType* v_buffer,
|
||||
uint32_t idx,
|
||||
uint32_t tx,
|
||||
uint32_t kv_head_idx,
|
||||
size_t k_buffer_stride_n,
|
||||
size_t k_buffer_stride_h,
|
||||
size_t v_buffer_stride_n,
|
||||
size_t v_buffer_stride_h) {
|
||||
DType* k_buffer_ptr =
|
||||
k_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, k_buffer_stride_n, k_buffer_stride_h);
|
||||
DType* v_buffer_ptr =
|
||||
v_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, v_buffer_stride_n, v_buffer_stride_h);
|
||||
k_vec.cast_store(k_buffer_ptr + tx * vec_size);
|
||||
v_vec.cast_store(v_buffer_ptr + tx * vec_size);
|
||||
}
|
||||
|
||||
} // namespace kv_buffer_saver
|
||||
|
||||
template <
|
||||
bool save_kv_cache,
|
||||
bool interleave,
|
||||
uint32_t head_dim,
|
||||
uint32_t vec_size,
|
||||
uint32_t bdx,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel(
|
||||
DType* q,
|
||||
DType* k,
|
||||
DType* v,
|
||||
DType* q_rope,
|
||||
DType* k_rope,
|
||||
DType* k_buffer,
|
||||
DType* v_buffer,
|
||||
float* __restrict__ cos_sin_cache,
|
||||
IdType* __restrict__ pos_ids,
|
||||
uint32_t nnz,
|
||||
uint32_t num_qo_heads,
|
||||
uint32_t num_kv_heads,
|
||||
uint32_t rotary_dim,
|
||||
size_t q_stride_n,
|
||||
size_t q_stride_h,
|
||||
size_t k_stride_n,
|
||||
size_t k_stride_h,
|
||||
size_t v_stride_n,
|
||||
size_t v_stride_h,
|
||||
size_t q_rope_stride_n,
|
||||
size_t q_rope_stride_h,
|
||||
size_t k_rope_stride_n,
|
||||
size_t k_rope_stride_h,
|
||||
size_t k_buffer_stride_n,
|
||||
size_t k_buffer_stride_h,
|
||||
size_t v_buffer_stride_n,
|
||||
size_t v_buffer_stride_h,
|
||||
IdType* __restrict__ kv_cache_loc) {
|
||||
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
|
||||
uint32_t by = blockIdx.y;
|
||||
const uint32_t bdy = blockDim.y;
|
||||
|
||||
vec_t<float, vec_size> cos, sin;
|
||||
if (bx * bdy + ty < nnz) {
|
||||
const uint32_t idx = bx * bdy + ty;
|
||||
const IdType pos = pos_ids[idx];
|
||||
|
||||
const int half_rotary_dim = rotary_dim / 2;
|
||||
|
||||
// 1. if interleave:
|
||||
// - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
|
||||
// - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
|
||||
// 2. if not interleave
|
||||
// - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
|
||||
// - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
|
||||
if (tx * vec_size < rotary_dim) {
|
||||
int sin_offset = rotary_dim / 2;
|
||||
int vec_idx;
|
||||
if constexpr (interleave) {
|
||||
vec_idx = (tx * vec_size) / 2; // Force integer division
|
||||
} else {
|
||||
vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim
|
||||
}
|
||||
cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
|
||||
sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
|
||||
}
|
||||
|
||||
if (by < num_qo_heads) {
|
||||
uint32_t qo_head_idx = by;
|
||||
DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
|
||||
DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
|
||||
vec_t<float, vec_size> q_vec;
|
||||
if constexpr (interleave) {
|
||||
q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
|
||||
} else {
|
||||
q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
|
||||
}
|
||||
q_vec.cast_store(q_rope_ptr + tx * vec_size);
|
||||
} else {
|
||||
uint32_t kv_head_idx = by - num_qo_heads;
|
||||
DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
|
||||
|
||||
DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
|
||||
|
||||
vec_t<float, vec_size> v_vec;
|
||||
IdType kv_cache_offset;
|
||||
if constexpr (save_kv_cache) {
|
||||
kv_buffer_saver::prepare<DType, IdType, vec_size>(
|
||||
v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
|
||||
}
|
||||
|
||||
vec_t<float, vec_size> k_vec;
|
||||
if constexpr (interleave) {
|
||||
k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
|
||||
} else {
|
||||
k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
|
||||
}
|
||||
k_vec.cast_store(k_rope_ptr + tx * vec_size);
|
||||
|
||||
if constexpr (save_kv_cache) {
|
||||
kv_buffer_saver::save<DType, IdType, vec_size>(
|
||||
kv_cache_offset,
|
||||
k_vec,
|
||||
v_vec,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
idx,
|
||||
tx,
|
||||
kv_head_idx,
|
||||
k_buffer_stride_n,
|
||||
k_buffer_stride_h,
|
||||
v_buffer_stride_n,
|
||||
v_buffer_stride_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
bool save_kv_cache,
|
||||
bool interleave,
|
||||
uint32_t head_dim,
|
||||
uint32_t vec_size,
|
||||
uint32_t bdx,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
|
||||
DType* q,
|
||||
DType* k,
|
||||
DType* v,
|
||||
DType* q_rope,
|
||||
DType* k_rope,
|
||||
DType* k_buffer,
|
||||
DType* v_buffer,
|
||||
float* __restrict__ cos_sin_cache,
|
||||
IdType* __restrict__ pos_ids,
|
||||
uint32_t nnz,
|
||||
uint32_t num_qo_heads,
|
||||
uint32_t num_kv_heads,
|
||||
uint32_t rotary_dim,
|
||||
size_t q_stride_n,
|
||||
size_t q_stride_h,
|
||||
size_t k_stride_n,
|
||||
size_t k_stride_h,
|
||||
size_t v_stride_n,
|
||||
size_t v_stride_h,
|
||||
size_t q_rope_stride_n,
|
||||
size_t q_rope_stride_h,
|
||||
size_t k_rope_stride_n,
|
||||
size_t k_rope_stride_h,
|
||||
size_t k_buffer_stride_n,
|
||||
size_t k_buffer_stride_h,
|
||||
size_t v_buffer_stride_n,
|
||||
size_t v_buffer_stride_h,
|
||||
IdType* __restrict__ kv_cache_loc) {
|
||||
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
|
||||
const uint32_t bdy = blockDim.y;
|
||||
|
||||
vec_t<float, vec_size> cos, sin;
|
||||
if (bx * bdy + ty < nnz) {
|
||||
const uint32_t idx = bx * bdy + ty;
|
||||
const IdType pos = pos_ids[idx];
|
||||
const int half_rotary_dim = rotary_dim / 2;
|
||||
|
||||
// 1. if interleave:
|
||||
// - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
|
||||
// - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
|
||||
// 2. if not interleave
|
||||
// - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
|
||||
// - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
|
||||
if (tx * vec_size < rotary_dim) {
|
||||
int sin_offset = rotary_dim / 2;
|
||||
int vec_idx;
|
||||
if constexpr (interleave) {
|
||||
vec_idx = (tx * vec_size) / 2; // Force integer division
|
||||
} else {
|
||||
vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim
|
||||
}
|
||||
cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
|
||||
sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
|
||||
}
|
||||
|
||||
// not to unroll the loop, because num head might be large and might lead to worse performance
|
||||
#pragma unroll 1
|
||||
for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
|
||||
DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
|
||||
DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
|
||||
vec_t<float, vec_size> q_vec;
|
||||
if constexpr (interleave) {
|
||||
q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
|
||||
} else {
|
||||
q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
|
||||
}
|
||||
q_vec.cast_store(q_rope_ptr + tx * vec_size);
|
||||
}
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
|
||||
DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
|
||||
|
||||
DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
|
||||
|
||||
vec_t<float, vec_size> v_vec;
|
||||
IdType kv_cache_offset;
|
||||
if constexpr (save_kv_cache) {
|
||||
kv_buffer_saver::prepare<DType, IdType, vec_size>(
|
||||
v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
|
||||
}
|
||||
|
||||
vec_t<float, vec_size> k_vec;
|
||||
if constexpr (interleave) {
|
||||
k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
|
||||
} else {
|
||||
k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
|
||||
}
|
||||
k_vec.cast_store(k_rope_ptr + tx * vec_size);
|
||||
|
||||
if constexpr (save_kv_cache) {
|
||||
kv_buffer_saver::save<DType, IdType, vec_size>(
|
||||
kv_cache_offset,
|
||||
k_vec,
|
||||
v_vec,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
idx,
|
||||
tx,
|
||||
kv_head_idx,
|
||||
k_buffer_stride_n,
|
||||
k_buffer_stride_h,
|
||||
v_buffer_stride_n,
|
||||
v_buffer_stride_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
|
||||
if (save_kv_cache) { \
|
||||
const bool SAVE_KV_CACHE = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
const bool SAVE_KV_CACHE = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType>
|
||||
cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
|
||||
DType* q,
|
||||
DType* k,
|
||||
DType* v,
|
||||
DType* q_rope,
|
||||
DType* k_rope,
|
||||
DType* k_buffer,
|
||||
DType* v_buffer,
|
||||
float* cos_sin_cache,
|
||||
IdType* pos_ids,
|
||||
uint32_t nnz,
|
||||
uint32_t num_qo_heads,
|
||||
uint32_t num_kv_heads,
|
||||
uint32_t rotary_dim,
|
||||
uint32_t head_dim,
|
||||
size_t q_stride_n,
|
||||
size_t q_stride_h,
|
||||
size_t k_stride_n,
|
||||
size_t k_stride_h,
|
||||
size_t v_stride_n,
|
||||
size_t v_stride_h,
|
||||
size_t q_rope_stride_n,
|
||||
size_t q_rope_stride_h,
|
||||
size_t k_rope_stride_n,
|
||||
size_t k_rope_stride_h,
|
||||
size_t k_buffer_stride_n,
|
||||
size_t k_buffer_stride_h,
|
||||
size_t v_buffer_stride_n,
|
||||
size_t v_buffer_stride_h,
|
||||
IdType* kv_cache_loc,
|
||||
bool interleave,
|
||||
bool save_kv_cache,
|
||||
cudaStream_t stream = nullptr) {
|
||||
int dev_id = 0;
|
||||
int num_sms = 0;
|
||||
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
|
||||
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
|
||||
|
||||
DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
|
||||
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
|
||||
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
|
||||
// operate on 16 Bytes at a time
|
||||
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
|
||||
// how many threads needed per head_dim
|
||||
constexpr uint32_t bdx = HEAD_DIM / vec_size;
|
||||
// how many threads needed per block
|
||||
uint32_t num_threads = std::max(128U, bdx);
|
||||
// how many tokens can we process in a block
|
||||
uint32_t bdy = num_threads / bdx;
|
||||
// how many blocks needed to process all tokens
|
||||
uint32_t nblks_x = (nnz + bdy - 1) / bdy;
|
||||
void* args[] = {
|
||||
(void*)&q,
|
||||
(void*)&k,
|
||||
(void*)&v,
|
||||
(void*)&q_rope,
|
||||
(void*)&k_rope,
|
||||
(void*)&k_buffer,
|
||||
(void*)&v_buffer,
|
||||
(void*)&cos_sin_cache,
|
||||
(void*)&pos_ids,
|
||||
(void*)&nnz,
|
||||
(void*)&num_qo_heads,
|
||||
(void*)&num_kv_heads,
|
||||
(void*)&rotary_dim,
|
||||
(void*)&q_stride_n,
|
||||
(void*)&q_stride_h,
|
||||
(void*)&k_stride_n,
|
||||
(void*)&k_stride_h,
|
||||
(void*)&v_stride_n,
|
||||
(void*)&v_stride_h,
|
||||
(void*)&q_rope_stride_n,
|
||||
(void*)&q_rope_stride_h,
|
||||
(void*)&k_rope_stride_n,
|
||||
(void*)&k_rope_stride_h,
|
||||
(void*)&k_buffer_stride_n,
|
||||
(void*)&k_buffer_stride_h,
|
||||
(void*)&v_buffer_stride_n,
|
||||
(void*)&v_buffer_stride_h,
|
||||
(void*)&kv_cache_loc};
|
||||
auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
|
||||
SAVE_KV_CACHE,
|
||||
INTERLEAVE,
|
||||
HEAD_DIM,
|
||||
vec_size,
|
||||
bdx,
|
||||
DType,
|
||||
IdType>;
|
||||
|
||||
int num_blocks_per_sm_0 = 0;
|
||||
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
|
||||
uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;
|
||||
|
||||
if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
|
||||
dim3 nblks(nblks_x);
|
||||
dim3 nthrs(bdx, bdy);
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
|
||||
} else {
|
||||
dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
|
||||
dim3 nthrs(bdx, bdy);
|
||||
auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel<
|
||||
SAVE_KV_CACHE,
|
||||
INTERLEAVE,
|
||||
HEAD_DIM,
|
||||
vec_size,
|
||||
bdx,
|
||||
DType,
|
||||
IdType>;
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
} // namespace flashinfer
|
||||
|
||||
#endif // SGL_POS_ENC_CUH_
|
||||
@@ -13,8 +13,8 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <flashinfer/pos_enc.cuh>
|
||||
|
||||
#include "pos_enc.cuh"
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
@@ -27,9 +27,37 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
at::Tensor cos_sin_cache,
|
||||
at::Tensor pos_ids,
|
||||
bool interleave,
|
||||
int64_t cuda_stream) {
|
||||
int64_t cuda_stream,
|
||||
const std::optional<at::Tensor>& v,
|
||||
const std::optional<at::Tensor>& k_buffer,
|
||||
const std::optional<at::Tensor>& v_buffer,
|
||||
const std::optional<at::Tensor>& kv_cache_loc) {
|
||||
CHECK_LAST_DIM_CONTIGUOUS(q);
|
||||
CHECK_LAST_DIM_CONTIGUOUS(k);
|
||||
|
||||
const bool save_kv_cache = v.has_value();
|
||||
if (save_kv_cache) {
|
||||
TORCH_CHECK(v.has_value());
|
||||
TORCH_CHECK(k_buffer.has_value());
|
||||
TORCH_CHECK(v_buffer.has_value());
|
||||
TORCH_CHECK(kv_cache_loc.has_value());
|
||||
CHECK_LAST_DIM_CONTIGUOUS(v.value());
|
||||
CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
|
||||
CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
|
||||
CHECK_DIM(3, k_buffer.value()); // k_buffer: (nnz, H_K, D)
|
||||
CHECK_DIM(3, v_buffer.value()); // v_buffer: (nnz, H_V, D)
|
||||
CHECK_DIM(3, v.value()); // v: (nnz, H_V, D)
|
||||
CHECK_DIM(1, kv_cache_loc.value()); // v: (n)
|
||||
CHECK_INPUT(kv_cache_loc.value());
|
||||
}
|
||||
size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0;
|
||||
size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0;
|
||||
size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0;
|
||||
size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0;
|
||||
size_t v_stride_n = save_kv_cache ? v->stride(0) : 0;
|
||||
size_t v_stride_h = save_kv_cache ? v->stride(1) : 0;
|
||||
auto kv_cache_loc_ptr = save_kv_cache ? static_cast<int64_t*>(kv_cache_loc->data_ptr()) : nullptr;
|
||||
|
||||
CHECK_INPUT(cos_sin_cache);
|
||||
CHECK_INPUT(pos_ids);
|
||||
auto device = q.device();
|
||||
@@ -38,6 +66,7 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
CHECK_EQ(pos_ids.device(), device);
|
||||
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
|
||||
CHECK_DIM(3, k); // k: (nnz, H_K, D)
|
||||
|
||||
// cos_sin_cache: (max_seq_len, R)
|
||||
// First half of R is cos, second half is sin
|
||||
CHECK_DIM(2, cos_sin_cache);
|
||||
@@ -52,6 +81,7 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
size_t q_stride_h = q.stride(1);
|
||||
size_t k_stride_n = k.stride(0);
|
||||
size_t k_stride_h = k.stride(1);
|
||||
|
||||
size_t q_rope_stride_n = q_rope.stride(0);
|
||||
size_t q_rope_stride_h = q_rope.stride(1);
|
||||
size_t k_rope_stride_n = k_rope.stride(0);
|
||||
@@ -59,31 +89,73 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
|
||||
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
|
||||
static_cast<c_type*>(q.data_ptr()),
|
||||
static_cast<c_type*>(k.data_ptr()),
|
||||
static_cast<c_type*>(q_rope.data_ptr()),
|
||||
static_cast<c_type*>(k_rope.data_ptr()),
|
||||
static_cast<float*>(cos_sin_cache.data_ptr()),
|
||||
static_cast<int64_t*>(pos_ids.data_ptr()),
|
||||
nnz,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
rotary_dim,
|
||||
head_dim,
|
||||
q_stride_n,
|
||||
q_stride_h,
|
||||
k_stride_n,
|
||||
k_stride_h,
|
||||
q_rope_stride_n,
|
||||
q_rope_stride_h,
|
||||
k_rope_stride_n,
|
||||
k_rope_stride_h,
|
||||
interleave,
|
||||
stream);
|
||||
TORCH_CHECK(
|
||||
status == cudaSuccess,
|
||||
"BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
|
||||
// TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
|
||||
// to avoid changing original code path; but this branch is feature-complete and should switch to this later
|
||||
if (save_kv_cache) {
|
||||
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
|
||||
static_cast<c_type*>(q.data_ptr()),
|
||||
static_cast<c_type*>(k.data_ptr()),
|
||||
save_kv_cache ? static_cast<c_type*>(v->data_ptr()) : nullptr,
|
||||
static_cast<c_type*>(q_rope.data_ptr()),
|
||||
static_cast<c_type*>(k_rope.data_ptr()),
|
||||
save_kv_cache ? static_cast<c_type*>(k_buffer->data_ptr()) : nullptr,
|
||||
save_kv_cache ? static_cast<c_type*>(v_buffer->data_ptr()) : nullptr,
|
||||
static_cast<float*>(cos_sin_cache.data_ptr()),
|
||||
static_cast<int64_t*>(pos_ids.data_ptr()),
|
||||
nnz,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
rotary_dim,
|
||||
head_dim,
|
||||
q_stride_n,
|
||||
q_stride_h,
|
||||
k_stride_n,
|
||||
k_stride_h,
|
||||
v_stride_n,
|
||||
v_stride_h,
|
||||
q_rope_stride_n,
|
||||
q_rope_stride_h,
|
||||
k_rope_stride_n,
|
||||
k_rope_stride_h,
|
||||
k_buffer_stride_n,
|
||||
k_buffer_stride_h,
|
||||
v_buffer_stride_n,
|
||||
v_buffer_stride_h,
|
||||
kv_cache_loc_ptr,
|
||||
interleave,
|
||||
save_kv_cache,
|
||||
stream);
|
||||
TORCH_CHECK(
|
||||
status == cudaSuccess,
|
||||
"BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
|
||||
std::string(cudaGetErrorString(status)));
|
||||
} else {
|
||||
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
|
||||
static_cast<c_type*>(q.data_ptr()),
|
||||
static_cast<c_type*>(k.data_ptr()),
|
||||
static_cast<c_type*>(q_rope.data_ptr()),
|
||||
static_cast<c_type*>(k_rope.data_ptr()),
|
||||
static_cast<float*>(cos_sin_cache.data_ptr()),
|
||||
static_cast<int64_t*>(pos_ids.data_ptr()),
|
||||
nnz,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
rotary_dim,
|
||||
head_dim,
|
||||
q_stride_n,
|
||||
q_stride_h,
|
||||
k_stride_n,
|
||||
k_stride_h,
|
||||
q_rope_stride_n,
|
||||
q_rope_stride_h,
|
||||
k_rope_stride_n,
|
||||
k_rope_stride_h,
|
||||
interleave,
|
||||
stream);
|
||||
TORCH_CHECK(
|
||||
status == cudaSuccess,
|
||||
"BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -150,7 +150,11 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
at::Tensor cos_sin_cache,
|
||||
at::Tensor pos_ids,
|
||||
bool interleave,
|
||||
int64_t cuda_stream);
|
||||
int64_t cuda_stream,
|
||||
const std::optional<at::Tensor>& v,
|
||||
const std::optional<at::Tensor>& k_buffer,
|
||||
const std::optional<at::Tensor>& v_buffer,
|
||||
const std::optional<at::Tensor>& kv_cache_loc);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||
|
||||
@@ -21,6 +21,7 @@ from sgl_kernel.attention import (
|
||||
)
|
||||
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
|
||||
from sgl_kernel.elementwise import (
|
||||
FusedSetKVBufferArg,
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
|
||||
@@ -237,6 +238,31 @@ if torch.version.hip is not None:
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedSetKVBufferArg:
|
||||
"""
|
||||
value : Optional[torch.Tensor]
|
||||
Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
|
||||
k_buffer : Optional[torch.Tensor]
|
||||
Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
|
||||
v_buffer : Optional[torch.Tensor]
|
||||
Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
|
||||
k_scale : Optional[float]
|
||||
Scale factor for keys.
|
||||
v_scale : Optional[float]
|
||||
Scale factor for values.
|
||||
cache_loc : Optional[torch.Tensor]
|
||||
Cache location tensor, used for indexing kv cache.
|
||||
"""
|
||||
|
||||
value: torch.Tensor
|
||||
k_buffer: torch.Tensor
|
||||
v_buffer: torch.Tensor
|
||||
k_scale: Optional[float]
|
||||
v_scale: Optional[float]
|
||||
cache_loc: torch.Tensor
|
||||
|
||||
|
||||
def apply_rope_with_cos_sin_cache_inplace(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
@@ -244,6 +270,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool = True,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Apply rotary embedding to keys and queries with precomputed cos/sin values.
|
||||
@@ -270,6 +297,9 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
|
||||
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
|
||||
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
|
||||
fused_set_kv_buffer_arg : FusedSetKVBufferArg
|
||||
Fuse the set-kv-buffer operation into this kernel
|
||||
|
||||
Note
|
||||
----
|
||||
The rotary dimension is determined by the cosine cache and sine cache.
|
||||
@@ -277,13 +307,41 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
if cos_sin_cache.dtype != torch.float32:
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
if (a := fused_set_kv_buffer_arg) is not None:
|
||||
assert a.k_scale is None, "k_scale is not yet supported"
|
||||
assert a.v_scale is None, "v_scale is not yet supported"
|
||||
assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
|
||||
|
||||
def _view_3d(x):
|
||||
return x.view(x.shape[0], -1, head_size)
|
||||
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
|
||||
query.view(query.shape[0], -1, head_size),
|
||||
key.view(key.shape[0], -1, head_size),
|
||||
query.view(query.shape[0], -1, head_size),
|
||||
key.view(key.shape[0], -1, head_size),
|
||||
_view_3d(query),
|
||||
_view_3d(key),
|
||||
_view_3d(query),
|
||||
_view_3d(key),
|
||||
cos_sin_cache,
|
||||
positions.long(),
|
||||
(not is_neox),
|
||||
get_cuda_stream(),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.value)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.k_buffer)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.v_buffer)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
fused_set_kv_buffer_arg.cache_loc
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
0
sgl-kernel/python/sgl_kernel/testing/__init__.py
Normal file
0
sgl-kernel/python/sgl_kernel/testing/__init__.py
Normal file
217
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py
Normal file
217
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
|
||||
|
||||
|
||||
# vLLM torch native
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
# Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (
|
||||
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
|
||||
)
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
|
||||
# Modification: float32 is required for the rotary embedding to work correctly
|
||||
query = query.to(torch.float32)
|
||||
key = key.to(torch.float32)
|
||||
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
|
||||
# Modification: convert to the correct dtype
|
||||
query = query.to(self.dtype)
|
||||
key = key.to(self.dtype)
|
||||
return query, key
|
||||
|
||||
|
||||
class FlashInferRotaryEmbedding(RotaryEmbedding):
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
apply_rope_with_cos_sin_cache_inplace(
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
|
||||
head_size=self.head_size,
|
||||
cos_sin_cache=self.cos_sin_cache,
|
||||
is_neox=self.is_neox_style,
|
||||
)
|
||||
|
||||
return query, key
|
||||
|
||||
|
||||
class MHATokenToKVPool:
|
||||
KV_POOL_SIZE = 16384
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
):
|
||||
self.head_num = head_num
|
||||
self.head_dim = head_dim
|
||||
self.size = MHATokenToKVPool.KV_POOL_SIZE
|
||||
self.page_size = 1
|
||||
self.store_dtype = torch.bfloat16
|
||||
self.device = "cuda"
|
||||
self.layer_num = 1
|
||||
self.start_layer = 0
|
||||
self._create_buffers()
|
||||
|
||||
def _create_buffers(self):
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
):
|
||||
layer_id = 0
|
||||
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||
|
||||
|
||||
def create_inputs(
|
||||
head_size: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
device,
|
||||
dtype: torch.dtype,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
|
||||
query = torch.randn(
|
||||
batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
|
||||
)
|
||||
key = torch.randn(
|
||||
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
|
||||
)
|
||||
value = torch.randn(
|
||||
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
|
||||
)
|
||||
out_cache_loc = torch.randperm(
|
||||
MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device
|
||||
)[: batch_size * seq_len].clone()
|
||||
|
||||
return dict(
|
||||
pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc
|
||||
)
|
||||
@@ -2,153 +2,51 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||
|
||||
|
||||
# vLLM torch native
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
# Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (
|
||||
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
|
||||
)
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
|
||||
# Modification: float32 is required for the rotary embedding to work correctly
|
||||
query = query.to(torch.float32)
|
||||
key = key.to(torch.float32)
|
||||
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
|
||||
# Modification: convert to the correct dtype
|
||||
query = query.to(self.dtype)
|
||||
key = key.to(self.dtype)
|
||||
return query, key
|
||||
|
||||
|
||||
class FlashInferRotaryEmbedding(RotaryEmbedding):
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
apply_rope_with_cos_sin_cache_inplace(
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
head_size=self.head_size,
|
||||
cos_sin_cache=self.cos_sin_cache,
|
||||
is_neox=self.is_neox_style,
|
||||
)
|
||||
|
||||
return query, key
|
||||
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
|
||||
from sgl_kernel.testing.rotary_embedding import (
|
||||
FlashInferRotaryEmbedding,
|
||||
MHATokenToKVPool,
|
||||
RotaryEmbedding,
|
||||
create_inputs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
|
||||
"head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads, save_kv_cache",
|
||||
[
|
||||
(64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
|
||||
(256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
|
||||
(512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
|
||||
(512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
|
||||
# GPT-OSS cases
|
||||
*[
|
||||
(
|
||||
64,
|
||||
64,
|
||||
4096,
|
||||
8000,
|
||||
True,
|
||||
torch.bfloat16,
|
||||
"cuda",
|
||||
batch_size,
|
||||
seq_len,
|
||||
64,
|
||||
8,
|
||||
save_kv_cache,
|
||||
)
|
||||
for batch_size, seq_len in (
|
||||
(1, 1),
|
||||
(32, 1),
|
||||
(128, 1),
|
||||
(512, 1),
|
||||
(2, 512),
|
||||
(4, 4096),
|
||||
)
|
||||
for save_kv_cache in (False, True)
|
||||
],
|
||||
# Other cases
|
||||
(64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False),
|
||||
(256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False),
|
||||
(512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
|
||||
(512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
|
||||
],
|
||||
)
|
||||
def test_correctness(
|
||||
@@ -163,34 +61,77 @@ def test_correctness(
|
||||
seq_len: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
save_kv_cache: bool,
|
||||
):
|
||||
rope_ref = RotaryEmbedding(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
).to(device)
|
||||
rope_flashinfer = FlashInferRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
).to(device)
|
||||
|
||||
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
|
||||
query = torch.randn(
|
||||
batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
|
||||
)
|
||||
key = torch.randn(
|
||||
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
|
||||
config = dict(
|
||||
head_size=head_size,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=base,
|
||||
is_neox_style=is_neox_style,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
query_ref, key_ref = query.clone(), key.clone()
|
||||
query_flashinfer, key_flashinfer = query.clone(), key.clone()
|
||||
rope_ref = RotaryEmbedding(**config).to(device)
|
||||
rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
|
||||
|
||||
inputs = create_inputs(
|
||||
head_size=head_size,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
)
|
||||
|
||||
if save_kv_cache:
|
||||
pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
|
||||
pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
|
||||
|
||||
query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone()
|
||||
query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()
|
||||
|
||||
query_ref_out, key_ref_out = rope_ref.forward_native(
|
||||
inputs["pos_ids"], query_ref, key_ref
|
||||
)
|
||||
if save_kv_cache:
|
||||
pool_ref.set_kv_buffer(
|
||||
loc=inputs["out_cache_loc"],
|
||||
cache_k=key_ref_out.view(-1, num_kv_heads, head_size),
|
||||
cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
|
||||
)
|
||||
|
||||
query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
|
||||
query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
|
||||
pos_ids, query_flashinfer, key_flashinfer
|
||||
inputs["pos_ids"],
|
||||
query_flashinfer,
|
||||
key_flashinfer,
|
||||
fused_set_kv_buffer_arg=(
|
||||
FusedSetKVBufferArg(
|
||||
value=inputs["value"],
|
||||
k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
|
||||
v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
|
||||
k_scale=None,
|
||||
v_scale=None,
|
||||
cache_loc=inputs["out_cache_loc"],
|
||||
)
|
||||
if save_kv_cache
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
|
||||
if save_kv_cache:
|
||||
for field in ["k_buffer", "v_buffer"]:
|
||||
x_ref = getattr(pool_ref, field)[0]
|
||||
x_flashinfer = getattr(pool_flashinfer, field)[0]
|
||||
torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
|
||||
nonzero_ref = x_ref != 0
|
||||
nonzero_flashinfer = x_ref != 0
|
||||
assert torch.all(nonzero_ref == nonzero_flashinfer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user