[kernel] port rope cuda kernel to sgl-kernel (#2993)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
118  sgl-kernel/tests/test_rotary_embedding.py  Normal file
@@ -0,0 +1,118 @@
from typing import Optional, Tuple

import torch
from vllm.model_executor.layers.rotary_embedding import (
    RotaryEmbedding as VLLMRotaryEmbedding,
)


class SGLRotaryEmbedding(VLLMRotaryEmbedding):

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from sgl_kernel import rotary_embedding

        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)

        rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key


# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native


def test_rotary_embedding():
    # Test case 1: FP32
    def run_test(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        dtype,
        batch_size,
        seq_len,
        num_heads,
        test_name,
    ):
        print(f"\nRunning {test_name}...")
        # Initialize both implementations
        sgl_rope = SGLRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")
        vllm_rope = VLLMRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")

        # Regular forward pass
        positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
        query = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )
        key = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )

        # Make copies for both implementations
        query_sgl = query.clone()
        key_sgl = key.clone()
        query_vllm = query.clone()
        key_vllm = key.clone()

        # Run both implementations
        query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
            positions, query_sgl, key_sgl
        )
        query_vllm_out, key_vllm_out = vllm_rope.forward_native(
            positions, query_vllm, key_vllm
        )

        # Compare outputs
        torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)

        print(f"{test_name} passed!")

    # Test Case 1: FP32 with larger dimensions
    run_test(
        head_size=128,
        rotary_dim=64,
        max_position=4096,
        base=10000,
        is_neox_style=True,
        dtype=torch.float32,
        batch_size=4,
        seq_len=32,
        num_heads=8,
        test_name="FP32 Test",
    )

    # Test Case 2: BF16 with smaller dimensions
    run_test(
        head_size=64,
        rotary_dim=32,
        max_position=2048,
        base=8000,
        is_neox_style=True,
        dtype=torch.bfloat16,
        batch_size=2,
        seq_len=16,
        num_heads=4,
        test_name="BF16 Test",
    )


if __name__ == "__main__":
    test_rotary_embedding()
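
For reference, the snippet below is a minimal sketch of calling the ported sgl_kernel.rotary_embedding op directly, without the wrapper class above. It assumes sgl_kernel is built with this kernel and a CUDA device is available; the cos/sin cache construction mirrors how VLLMRotaryEmbedding builds its cos_sin_cache and is an illustrative assumption, not part of this diff.

import torch
from sgl_kernel import rotary_embedding

# Shapes taken from the FP32 case above (partial rotary: rotary_dim < head_size).
head_size, num_heads, rotary_dim, max_position, base = 128, 8, 64, 4096, 10000
seq_len = 32

positions = torch.arange(seq_len, device="cuda")
query = torch.randn(seq_len, num_heads * head_size, device="cuda", dtype=torch.float32)
key = torch.randn(seq_len, num_heads * head_size, device="cuda", dtype=torch.float32)

# Assumed cache layout: [max_position, rotary_dim] with cos in the first half and
# sin in the second, matching how the vLLM RotaryEmbedding builds its cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.einsum("i,j->ij", torch.arange(max_position, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to("cuda")

# The op rotates query and key in place; the final flag selects NeoX-style rotation.
rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)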