[kernel] port rope cuda kernel to sgl-kernel (#2993)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -222,3 +222,6 @@ work_dirs/
|
||||
compile_commands.json
|
||||
|
||||
*.iml
|
||||
|
||||
# VSCode
|
||||
.vscode
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
version = "0.0.2.post14"
|
||||
version = "0.0.2.post15"
|
||||
description = "Kernel Library for SGLang"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
|
||||
@@ -53,6 +53,7 @@ ext_modules = [
|
||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||
],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
|
||||
@@ -6,6 +6,7 @@ from sgl_kernel.ops import (
|
||||
int8_scaled_mm,
|
||||
moe_align_block_size,
|
||||
register_graph_buffers,
|
||||
rotary_embedding,
|
||||
sampling_scaling_penalties,
|
||||
)
|
||||
|
||||
@@ -18,4 +19,5 @@ __all__ = [
|
||||
"sampling_scaling_penalties",
|
||||
"get_graph_buffer_ipc_meta",
|
||||
"register_graph_buffers",
|
||||
"rotary_embedding",
|
||||
]
|
||||
|
||||
119
sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
Normal file
119
sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
Normal file
@@ -0,0 +1,119 @@
|
||||
// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr, int rot_offset,
|
||||
int embed_dim) {
|
||||
int x_index, y_index;
|
||||
scalar_t cos, sin;
|
||||
if (IS_NEOX) {
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = __ldg(cos_ptr + x_index);
|
||||
sin = __ldg(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = __ldg(cos_ptr + x_index / 2);
|
||||
sin = __ldg(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
const scalar_t y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
||||
const int num_kv_heads, const int rot_dim, const int token_idx,
|
||||
const int64_t query_stride, const int64_t key_stride) {
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
__global__ void rotary_embedding_kernel(const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride);
|
||||
}
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size]
|
||||
int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox) {
|
||||
int64_t num_tokens = query.numel() / query.size(-1);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] {
|
||||
if (is_neox) {
|
||||
rotary_embedding_kernel<scalar_t, true>
|
||||
<<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
|
||||
query_stride, key_stride, num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
|
||||
query_stride, key_stride, num_heads, num_kv_heads, head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -26,6 +26,10 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
|
||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
|
||||
// rotary embedding
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// trt_reduce
|
||||
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
|
||||
@@ -39,4 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
|
||||
// int8_scaled_mm
|
||||
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
|
||||
// rotary embedding
|
||||
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
|
||||
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
|
||||
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
|
||||
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
|
||||
from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
|
||||
from sgl_kernel.ops._kernels import (
|
||||
sampling_scaling_penalties as _sampling_scaling_penalties,
|
||||
)
|
||||
@@ -71,3 +72,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
out_dtype,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
|
||||
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
|
||||
|
||||
118
sgl-kernel/tests/test_rotary_embedding.py
Normal file
118
sgl-kernel/tests/test_rotary_embedding.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
RotaryEmbedding as VLLMRotaryEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class SGLRotaryEmbedding(VLLMRotaryEmbedding):
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
from sgl_kernel import rotary_embedding
|
||||
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
|
||||
rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
|
||||
# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native
|
||||
|
||||
|
||||
def test_rotary_embedding():
|
||||
# Test case 1: FP32
|
||||
def run_test(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
test_name,
|
||||
):
|
||||
print(f"\nRunning {test_name}...")
|
||||
# Initialize both implementations
|
||||
sgl_rope = SGLRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
||||
).to("cuda")
|
||||
vllm_rope = VLLMRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
||||
).to("cuda")
|
||||
|
||||
# Regular forward pass
|
||||
positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
|
||||
query = torch.randn(
|
||||
batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
|
||||
)
|
||||
key = torch.randn(
|
||||
batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
|
||||
)
|
||||
|
||||
# Make copies for both implementations
|
||||
query_sgl = query.clone()
|
||||
key_sgl = key.clone()
|
||||
query_vllm = query.clone()
|
||||
key_vllm = key.clone()
|
||||
|
||||
# Run both implementations
|
||||
query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
|
||||
positions, query_sgl, key_sgl
|
||||
)
|
||||
query_vllm_out, key_vllm_out = vllm_rope.forward_native(
|
||||
positions, query_vllm, key_vllm
|
||||
)
|
||||
|
||||
# Compare outputs
|
||||
torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)
|
||||
|
||||
print(f"{test_name} passed!")
|
||||
|
||||
# Test Case 1: FP32 with larger dimensions
|
||||
run_test(
|
||||
head_size=128,
|
||||
rotary_dim=64,
|
||||
max_position=4096,
|
||||
base=10000,
|
||||
is_neox_style=True,
|
||||
dtype=torch.float32,
|
||||
batch_size=4,
|
||||
seq_len=32,
|
||||
num_heads=8,
|
||||
test_name="FP32 Test",
|
||||
)
|
||||
|
||||
# Test Case 2: BF16 with smaller dimensions
|
||||
run_test(
|
||||
head_size=64,
|
||||
rotary_dim=32,
|
||||
max_position=2048,
|
||||
base=8000,
|
||||
is_neox_style=True,
|
||||
dtype=torch.bfloat16,
|
||||
batch_size=2,
|
||||
seq_len=16,
|
||||
num_heads=4,
|
||||
test_name="BF16 Test",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rotary_embedding()
|
||||
Reference in New Issue
Block a user