Add CPU-optimized kernels for topk and rope fusions (#6456)

Author: jianan-gu
Date: 2025-06-03 08:37:34 +08:00
Committed by: GitHub
Parent: ff91474825
Commit: ff00895c46
7 changed files with 829 additions and 98 deletions
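
The headline rope change below swaps the old torch.ops.sgl_kernel.rotary_position_embedding_cpu call for a new fused op, torch.ops.sgl_kernel.rotary_embedding_cpu, which takes head_size and an is_neox_style flag explicitly and rotates query and key in one pass. A minimal call sketch pieced together from the call sites in this diff (it assumes an sgl-kernel CPU build that registers the op; shapes and constructor arguments mirror the new test):

    import torch
    import sgl_kernel  # registers torch.ops.sgl_kernel.* (CPU build assumed)
    from sglang.srt.layers.rotary_embedding import RotaryEmbedding

    # Reference module supplies head_size, cos_sin_cache, and is_neox_style,
    # exactly as the new test does.
    rope = RotaryEmbedding(64, 64, 2048, 10000, True, torch.bfloat16)

    positions = torch.arange(16)
    query = torch.randn(16, 4 * 64, dtype=torch.bfloat16)  # [tokens, q_heads * head_size]
    key = torch.randn(16, 1 * 64, dtype=torch.bfloat16)  # [tokens, kv_heads * head_size]

    # Fused CPU kernel added by this commit: applies rope to query and key in one call.
    q_out, k_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
        positions,
        query,
        key,
        rope.head_size,
        rope.cos_sin_cache.to(query.dtype),
        rope.is_neox_style,
    )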

@@ -4,7 +4,10 @@ import sgl_kernel
 import torch
 from utils import precision
-from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
+from sglang.srt.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding,
+    RotaryEmbedding,
+)
 from sglang.test.test_utils import CustomTestCase
@@ -62,10 +65,13 @@ class TestROPE(CustomTestCase):
         )
         # fused rope kernel
-        q_pe_clone, k_pe_clone = (
-            torch.ops.sgl_kernel.rotary_position_embedding_cpu(
-                positions, q_pe_clone, k_pe_clone, cos_sin_cache
-            )
+        q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
+            positions,
+            q_pe_clone,
+            k_pe_clone,
+            rope.head_size,
+            cos_sin_cache,
+            False,
         )
         atol = rtol = precision[q_pe.dtype]
@@ -73,6 +79,98 @@ class TestROPE(CustomTestCase):
         self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
         torch.testing.assert_close(k_pe, k_pe_clone)
 
+    def test_origin_rope(self):
+        def single_test(
+            head_size: int,
+            rotary_dim: int,
+            max_position_embeddings: int,
+            base: int,
+            is_neox_style: bool,
+            dtype: torch.dtype,
+            device: str,
+            batch_size: int,
+            seq_len: int,
+            num_q_heads: int,
+            num_kv_heads: int,
+        ):
+            torch.manual_seed(100)
+            rope_ref = RotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position_embeddings,
+                base,
+                is_neox_style,
+                dtype,
+            ).to(device)
+            pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
+            query = torch.randn(
+                batch_size * seq_len,
+                num_q_heads * head_size,
+                dtype=dtype,
+                device=device,
+            )
+            key = torch.randn(
+                batch_size * seq_len,
+                num_kv_heads * head_size,
+                dtype=dtype,
+                device=device,
+            )
+            query_ref, key_ref = query.clone(), key.clone()
+            query_cpu, key_cpu = query.clone(), key.clone()
+
+            query_ref_out, key_ref_out = rope_ref.forward_native(
+                pos_ids, query_ref, key_ref
+            )
+            query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
+                pos_ids,
+                query_cpu,
+                key_cpu,
+                rope_ref.head_size,
+                rope_ref.cos_sin_cache.to(query.dtype),
+                rope_ref.is_neox_style,
+            )
+            torch.testing.assert_close(
+                query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
+            )
+            torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)
+
+        test_config = [
+            (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
+            (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
+            (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
+            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
+            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
+            (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
+        ]
+        for (
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+            device,
+            batch_size,
+            seq_len,
+            num_q_heads,
+            num_kv_heads,
+        ) in test_config:
+            single_test(
+                head_size,
+                rotary_dim,
+                max_position_embeddings,
+                base,
+                is_neox_style,
+                dtype,
+                device,
+                batch_size,
+                seq_len,
+                num_q_heads,
+                num_kv_heads,
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
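
As background on what forward_native checks the fused kernel against, the neox-style rotation can be written unfused in a few lines of PyTorch. This is an illustrative sketch of the standard algorithm under the cos_sin_cache layout used above (cos half concatenated with sin half per position), not code from this commit; the interleaved is_neox_style=False path is omitted:

    import torch

    def rope_neox_reference(positions, x, head_size, cos_sin_cache):
        # positions: [tokens]; x: [tokens, heads * head_size]
        # cos_sin_cache: [max_pos, rotary_dim], cos half then sin half
        rotary_dim = cos_sin_cache.shape[-1]
        cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
        cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over heads
        x = x.view(x.shape[0], -1, head_size)
        rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
        x1, x2 = rot.chunk(2, dim=-1)
        rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        return torch.cat((rot, rest), dim=-1).flatten(-2)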