# sglang/sgl-kernel/tests/test_rotary_embedding.py
from typing import Optional, Tuple
import torch
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding as VLLMRotaryEmbedding,
)
class SGLRotaryEmbedding(VLLMRotaryEmbedding):
    """Rotary embedding whose CUDA path dispatches to the sgl_kernel op.

    Inherits table construction from ``VLLMRotaryEmbedding``; only the CUDA
    forward is overridden to call ``sgl_kernel.rotary_embedding`` instead.
    """

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embedding to ``query``/``key`` in place.

        ``offsets`` is accepted for signature compatibility with the base
        class but is not forwarded to the kernel.  Returns the (mutated)
        ``query`` and ``key`` tensors.
        """
        # Imported lazily so the module can be imported without sgl_kernel.
        from sgl_kernel import rotary_embedding

        # Keep the cached cos/sin table on the input's device and dtype.
        cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
        self.cos_sin_cache = cache
        rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            cache,
            self.is_neox_style,
        )
        return query, key
def test_rotary_embedding():
    """Compare SGLRotaryEmbedding.forward_cuda against VLLMRotaryEmbedding.forward_native.

    Runs two configurations (FP32 and BF16) on identical random inputs and
    asserts the outputs match within loose tolerances.
    """

    def _check_case(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        dtype,
        batch_size,
        seq_len,
        num_heads,
        test_name,
    ):
        print(f"\nRunning {test_name}...")

        # Build both implementations with identical hyperparameters.
        sgl_rope = SGLRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")
        vllm_rope = VLLMRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")

        # Shared random inputs; both paths get their own clones because the
        # kernels rotate query/key in place.
        tokens = batch_size * seq_len
        positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
        query = torch.randn(
            tokens, num_heads * head_size, device="cuda", dtype=dtype
        )
        key = torch.randn(
            tokens, num_heads * head_size, device="cuda", dtype=dtype
        )

        q_sgl, k_sgl = sgl_rope.forward_cuda(positions, query.clone(), key.clone())
        q_ref, k_ref = vllm_rope.forward_native(positions, query.clone(), key.clone())

        torch.testing.assert_close(q_sgl, q_ref, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(k_sgl, k_ref, rtol=1e-3, atol=1e-3)
        print(f"{test_name} passed!")

    cases = [
        # Test Case 1: FP32 with larger dimensions
        dict(
            head_size=128,
            rotary_dim=64,
            max_position=4096,
            base=10000,
            is_neox_style=True,
            dtype=torch.float32,
            batch_size=4,
            seq_len=32,
            num_heads=8,
            test_name="FP32 Test",
        ),
        # Test Case 2: BF16 with smaller dimensions
        dict(
            head_size=64,
            rotary_dim=32,
            max_position=2048,
            base=8000,
            is_neox_style=True,
            dtype=torch.bfloat16,
            batch_size=2,
            seq_len=16,
            num_heads=4,
            test_name="BF16 Test",
        ),
    ]
    for case in cases:
        _check_case(**case)
if __name__ == "__main__":
    # Allow running this test directly without pytest.
    test_rotary_embedding()