Add CPU optimized kernels for topk and rope fusions (#6456)
This commit is contained in:
@@ -4,7 +4,10 @@ import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
|
||||
from sglang.srt.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -62,10 +65,13 @@ class TestROPE(CustomTestCase):
|
||||
)
|
||||
|
||||
# fused rope kernel
|
||||
q_pe_clone, k_pe_clone = (
|
||||
torch.ops.sgl_kernel.rotary_position_embedding_cpu(
|
||||
positions, q_pe_clone, k_pe_clone, cos_sin_cache
|
||||
)
|
||||
q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||
positions,
|
||||
q_pe_clone,
|
||||
k_pe_clone,
|
||||
rope.head_size,
|
||||
cos_sin_cache,
|
||||
False,
|
||||
)
|
||||
|
||||
atol = rtol = precision[q_pe.dtype]
|
||||
@@ -73,6 +79,98 @@ class TestROPE(CustomTestCase):
|
||||
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
|
||||
torch.testing.assert_close(k_pe, k_pe_clone)
|
||||
|
||||
def test_origin_rope(self):
|
||||
def single_test(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
torch.manual_seed(100)
|
||||
rope_ref = RotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
).to(device)
|
||||
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
|
||||
query = torch.randn(
|
||||
batch_size * seq_len,
|
||||
num_q_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
key = torch.randn(
|
||||
batch_size * seq_len,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
query_ref, key_ref = query.clone(), key.clone()
|
||||
query_cpu, key_cpu = query.clone(), key.clone()
|
||||
|
||||
query_ref_out, key_ref_out = rope_ref.forward_native(
|
||||
pos_ids, query_ref, key_ref
|
||||
)
|
||||
query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||
pos_ids,
|
||||
query_cpu,
|
||||
key_cpu,
|
||||
rope_ref.head_size,
|
||||
rope_ref.cos_sin_cache.to(query.dtype),
|
||||
rope_ref.is_neox_style,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
test_config = [
|
||||
(64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
|
||||
(256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
|
||||
(512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
|
||||
(512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
|
||||
]
|
||||
|
||||
for (
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
) in test_config:
|
||||
single_test(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user