CPU: map changes from developing branch in sgl-kernel (#6833)

Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
YanbingJiang
2025-06-10 16:08:15 +08:00
committed by GitHub
parent 81372f3bef
commit fcde67b016
20 changed files with 1321 additions and 321 deletions

155
test/srt/cpu/test_mla.py Normal file
View File

@@ -0,0 +1,155 @@
import itertools
import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from utils import precision
from sglang.test.test_utils import CustomTestCase
class TestMLA(CustomTestCase):
def _run_sdpa_forward_decode(
self,
query: torch.Tensor,
output: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
key: torch.Tensor,
loc: torch.Tensor,
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
scaling=None,
enable_gqa=False,
causal=False,
):
# set kv cache
k_cache[loc] = key
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
query = query.movedim(0, query.dim() - 2)
start_q, start_kv = 0, 0
for seq_idx in range(seq_lens.shape[0]):
seq_len_q = 1
seq_len_kv = seq_lens[seq_idx]
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
per_req_query = query[:, start_q:end_q, :]
# get key and value from cache. per_req_tokens contains the kv cache
# index for each token in the sequence.
req_pool_idx = req_pool_indices[seq_idx]
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_out = (
scaled_dot_product_attention(
per_req_query.unsqueeze(0),
per_req_key.unsqueeze(0),
per_req_value.unsqueeze(0),
enable_gqa=enable_gqa,
scale=scaling,
is_causal=causal,
)
.squeeze(0)
.movedim(query.dim() - 2, 0)
)
output[start_q:end_q, :, :] = per_req_out
start_q, start_kv = end_q, end_kv
return output
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
dtype = torch.bfloat16
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)
logit_cap = 0.0
num_kv_splits = 8
enable_gqa = H_Q != H_KV
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype)
# k_buffer and v_buffer represent all previous tokens
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
v_buffer = k_buffer.narrow(2, 0, D_V)
key = torch.randn(B, H_KV, D, dtype=dtype)
value = key.narrow(2, 0, D_V)
# make sure no duplicates in loc
loc = torch.randperm(total_tokens)[:B].to(torch.int64)
k_buffer2 = k_buffer.clone()
v_buffer2 = k_buffer2.narrow(2, 0, D_V)
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype)
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)
req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
b_req_idx = torch.arange(B).to(torch.int64)
b_seq_len = torch.full((B,), seq_len).to(torch.int64)
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
)
torch.ops.sgl_kernel.decode_attention_cpu(
q,
k_buffer2,
v_buffer2,
o,
key,
value,
loc,
attn_logits,
req_to_token,
b_req_idx,
b_seq_len,
sm_scale,
logit_cap,
)
self._run_sdpa_forward_decode(
q,
o_grouped,
k_buffer,
v_buffer,
key,
loc,
req_to_token,
b_req_idx,
b_seq_len,
scaling=sm_scale,
enable_gqa=enable_gqa,
)
cos_sim = torch.nn.functional.cosine_similarity(
o.flatten(), o_grouped.flatten(), dim=0
)
atol = rtol = precision[q.dtype]
self.assertGreater(cos_sim.item(), 0.99)
torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol)
torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol)
torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol)
def test_grouped_decode_attention(self):
configs = [
(1, 22, 1, 576, 512, 8 * 111),
(4, 22, 1, 576, 512, 8 * 128),
(40, 22, 1, 576, 512, 8 * 133),
]
for B, H_Q, H_KV, D, D_V, seqlen in configs:
self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)
if __name__ == "__main__":
unittest.main()

View File

@@ -33,7 +33,7 @@ def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
topk_weights = torch.empty(B, topk, dtype=torch.float32)
topk_ids = torch.empty(B, topk, dtype=torch.int32)
topk_weights, topk_ids = kernel.grouped_topk_cpu(
a, score, topk, renormalize, G, topk_group
a, score, topk, renormalize, G, topk_group, 0, None, None
)
packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1

View File

@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
import sgl_kernel
import torch
from utils import precision
from utils import make_non_contiguous, precision
from sglang.test.test_utils import CustomTestCase
@@ -38,6 +38,7 @@ class TestNorm(CustomTestCase):
def _norm_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
x = make_non_contiguous(x)
hidden_size = x.size(-1)
weight = torch.randn(hidden_size, dtype=dtype)
variance_epsilon = 1e-6
@@ -49,7 +50,7 @@ class TestNorm(CustomTestCase):
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
ref_x = x.clone()
residual = torch.randn([m, n], dtype=dtype)
residual = torch.randn([m, hidden_size], dtype=dtype)
ref_residual = residual.clone()
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(

View File

@@ -14,6 +14,7 @@ from sglang.test.test_utils import CustomTestCase
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
torch.manual_seed(0)
# constants
kv_lora_rank = 512
@@ -148,6 +149,7 @@ class TestQKVProjWithROPE(CustomTestCase):
kv_a_proj_weight = (
torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
)
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
pos = torch.randint(10, 100, (B,))
cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
@@ -167,6 +169,7 @@ class TestQKVProjWithROPE(CustomTestCase):
qb_packed = convert_weight_packed(q_b_proj_weight)
kva_packed = convert_weight_packed(kv_a_proj_weight)
wkc_packed = convert_weight_packed(w_kc)
fused_weight_packed = convert_weight_packed(fused_weight)
q_out, k_out, v_out = qkv_proj_with_rope(
hidden_states,
@@ -187,10 +190,33 @@ class TestQKVProjWithROPE(CustomTestCase):
True,
None,
)
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
hidden_states,
fused_weight_packed,
qb_packed,
wkc_packed,
norm_weight1,
norm_weight2,
pos,
cos_sin_cache,
eps,
False,
False,
None,
None,
True,
None,
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
)
atol = rtol = precision[q_ref.dtype]
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fused_q_out, q_out))
self.assertTrue(torch.allclose(fused_k_out, k_out))
self.assertTrue(torch.allclose(fused_v_out, v_out))
def test_int8_qkv_proj_with_rope(self):
dtype = torch.bfloat16
@@ -252,10 +278,36 @@ class TestQKVProjWithROPE(CustomTestCase):
True,
None,
)
fused_weight = torch.cat([w1_q, w3_q], dim=0)
fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
w_fused_q_packed = convert_weight_packed(fused_weight)
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
hidden_states,
w_fused_q_packed,
w2_q_packed,
wkc_packed,
norm_weight1,
norm_weight2,
pos,
cos_sin_cache,
eps,
True,
False,
fused_weight_s,
w2_s,
True,
None,
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
)
atol = rtol = precision[q_ref.dtype]
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fused_q_out, q_out))
self.assertTrue(torch.allclose(fused_k_out, k_out))
self.assertTrue(torch.allclose(fused_v_out, v_out))
def test_fp8_qkv_proj_with_rope(self):
dtype = torch.bfloat16
@@ -311,17 +363,17 @@ class TestQKVProjWithROPE(CustomTestCase):
pos,
cos_sin_cache,
)
fp8_q_a_proj_weight = convert_weight_packed(fp8_q_a_proj_weight)
fp8_q_b_proj_weight = convert_weight_packed(fp8_q_b_proj_weight)
fp8_kv_a_proj_with_mqa_weight = convert_weight_packed(
fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
fp8_kv_a_proj_with_mqa_weight
)
w_kc = convert_weight_packed(w_kc)
q_out, k_out, v_out = qkv_proj_with_rope(
hidden_states,
fp8_q_a_proj_weight,
fp8_q_b_proj_weight,
fp8_kv_a_proj_with_mqa_weight,
fp8_q_a_proj_weight_packed,
fp8_q_b_proj_weight_packed,
fp8_kv_a_proj_with_mqa_weight_packed,
w_kc,
norm_weight1,
norm_weight2,
@@ -336,10 +388,44 @@ class TestQKVProjWithROPE(CustomTestCase):
True,
[scale_block_size_N, scale_block_size_K],
)
fused_weight = torch.cat(
[fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
)
fused_weight_s = torch.cat(
[q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
)
fused_weight_packed = convert_weight_packed(fused_weight)
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
hidden_states,
fused_weight_packed,
fp8_q_b_proj_weight_packed,
w_kc,
norm_weight1,
norm_weight2,
pos,
cos_sin_cache,
eps,
False,
True,
fused_weight_s.float(),
q_b_proj_weight_scale_inv.float(),
True,
[scale_block_size_N, scale_block_size_K],
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
)
atol = rtol = precision[q_ref.dtype]
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
# Due to the change in multiplication order, the error is amplified.
# In the model, with fewer layers, this doesn't cause issues, but in
# tests with more layers, we need to enlarge the tolerance to pass the tests.
torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_q_out, q_out)
torch.testing.assert_close(fused_k_out, k_out)
torch.testing.assert_close(fused_v_out, v_out)
if __name__ == "__main__":

View File

@@ -34,7 +34,15 @@ class TestGroupedTopK(CustomTestCase):
# fused version
topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
hidden_states, gating_output, topk, renormalize, G, topk_group
hidden_states,
gating_output,
topk,
renormalize,
G,
topk_group,
0,
None,
None,
)
res = torch.zeros(M, E, dtype=torch.float)
@@ -83,6 +91,9 @@ class TestBiasedGroupedTopK(CustomTestCase):
renormalize,
G,
topk_group,
0,
None,
None,
)
res = torch.zeros(M, E, dtype=torch.float)

View File

@@ -244,3 +244,11 @@ def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
.sum(dim=1)
.to(a.dtype)
)
def make_non_contiguous(x: torch.Tensor) -> torch.Tensor:
"""
Make a tensor non-contiguous by slicing it via last dimension.
"""
last_dim = x.shape[-1]
return x[..., : last_dim // 2] if x.is_contiguous() else x