CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
155  test/srt/cpu/test_mla.py  (new file)
@@ -0,0 +1,155 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from utils import precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestMLA(CustomTestCase):
    """Tests for the CPU MLA decode-attention kernel in sgl-kernel.

    The fused kernel ``torch.ops.sgl_kernel.decode_attention_cpu`` is compared
    against a per-sequence reference built on
    ``torch.nn.functional.scaled_dot_product_attention``.
    """

    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        key: torch.Tensor,
        loc: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference decode attention, one sequence at a time.

        Writes ``key`` into ``k_cache`` at ``loc`` first; ``v_cache`` is
        expected to alias a narrow view of ``k_cache`` (MLA layout), so the
        value cache is updated implicitly by the same store. Fills ``output``
        in place and also returns it.
        """
        # set kv cache
        k_cache[loc] = key

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            # Decode phase: exactly one new query token per sequence.
            seq_len_q = 1
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv

            per_req_query = query[:, start_q:end_q, :]

            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out = (
                scaled_dot_product_attention(
                    per_req_query.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out
            start_q, start_kv = end_q, end_kv

        return output

    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
        """Run the fused kernel and the SDPA reference once and compare.

        B: batch size; H_Q/H_KV: query/kv head counts; D: head dim of q/k;
        D_V: head dim of v (a narrow of D, MLA-style); seq_len: kv length.
        """
        dtype = torch.bfloat16

        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0
        num_kv_splits = 8
        enable_gqa = H_Q != H_KV

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype)

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
        # v shares storage with k: the value head is the first D_V lanes of k.
        v_buffer = k_buffer.narrow(2, 0, D_V)

        key = torch.randn(B, H_KV, D, dtype=dtype)
        value = key.narrow(2, 0, D_V)
        # make sure no duplicates in loc
        loc = torch.randperm(total_tokens)[:B].to(torch.int64)

        # Independent copy for the fused kernel so both paths start from the
        # same cache contents and can be compared afterwards.
        k_buffer2 = k_buffer.clone()
        v_buffer2 = k_buffer2.narrow(2, 0, D_V)

        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D_V, dtype=dtype)
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)

        req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
        b_req_idx = torch.arange(B).to(torch.int64)
        b_seq_len = torch.full((B,), seq_len).to(torch.int64)

        # Scratch buffer for the split-KV reduction inside the kernel.
        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1),
            dtype=torch.float32,
        )

        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer2,
            v_buffer2,
            o,
            key,
            value,
            loc,
            attn_logits,
            req_to_token,
            b_req_idx,
            b_seq_len,
            sm_scale,
            logit_cap,
        )

        self._run_sdpa_forward_decode(
            q,
            o_grouped,
            k_buffer,
            v_buffer,
            key,
            loc,
            req_to_token,
            b_req_idx,
            b_seq_len,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
        )

        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        atol = rtol = precision[q.dtype]
        self.assertGreater(cos_sim.item(), 0.99)
        torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol)
        # The kernel must also have written the new key into its cache copy
        # identically to the reference path.
        torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol)

    def test_grouped_decode_attention(self):
        # (B, H_Q, H_KV, D, D_V, seq_len) — DeepSeek-style MLA decode shapes.
        configs = [
            (1, 22, 1, 576, 512, 8 * 111),
            (4, 22, 1, 576, 512, 8 * 128),
            (40, 22, 1, 576, 512, 8 * 133),
        ]

        for B, H_Q, H_KV, D, D_V, seqlen in configs:
            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test file directly: `python test_mla.py`.
    unittest.main()
|
||||
@@ -33,7 +33,7 @@ def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
|
||||
topk_weights = torch.empty(B, topk, dtype=torch.float32)
|
||||
topk_ids = torch.empty(B, topk, dtype=torch.int32)
|
||||
topk_weights, topk_ids = kernel.grouped_topk_cpu(
|
||||
a, score, topk, renormalize, G, topk_group
|
||||
a, score, topk, renormalize, G, topk_group, 0, None, None
|
||||
)
|
||||
|
||||
packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
from utils import make_non_contiguous, precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
@@ -38,6 +38,7 @@ class TestNorm(CustomTestCase):
|
||||
def _norm_test(self, m, n, dtype):
|
||||
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
x = make_non_contiguous(x)
|
||||
hidden_size = x.size(-1)
|
||||
weight = torch.randn(hidden_size, dtype=dtype)
|
||||
variance_epsilon = 1e-6
|
||||
@@ -49,7 +50,7 @@ class TestNorm(CustomTestCase):
|
||||
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||
|
||||
ref_x = x.clone()
|
||||
residual = torch.randn([m, n], dtype=dtype)
|
||||
residual = torch.randn([m, hidden_size], dtype=dtype)
|
||||
ref_residual = residual.clone()
|
||||
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
|
||||
|
||||
@@ -14,6 +14,7 @@ from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
|
||||
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
|
||||
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
|
||||
torch.manual_seed(0)
|
||||
# constants
|
||||
kv_lora_rank = 512
|
||||
@@ -148,6 +149,7 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
kv_a_proj_weight = (
|
||||
torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
|
||||
)
|
||||
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
|
||||
norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
|
||||
pos = torch.randint(10, 100, (B,))
|
||||
cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
|
||||
@@ -167,6 +169,7 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
qb_packed = convert_weight_packed(q_b_proj_weight)
|
||||
kva_packed = convert_weight_packed(kv_a_proj_weight)
|
||||
wkc_packed = convert_weight_packed(w_kc)
|
||||
fused_weight_packed = convert_weight_packed(fused_weight)
|
||||
|
||||
q_out, k_out, v_out = qkv_proj_with_rope(
|
||||
hidden_states,
|
||||
@@ -187,10 +190,33 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
True,
|
||||
None,
|
||||
)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
fused_weight_packed,
|
||||
qb_packed,
|
||||
wkc_packed,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
None,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(fused_q_out, q_out))
|
||||
self.assertTrue(torch.allclose(fused_k_out, k_out))
|
||||
self.assertTrue(torch.allclose(fused_v_out, v_out))
|
||||
|
||||
def test_int8_qkv_proj_with_rope(self):
|
||||
dtype = torch.bfloat16
|
||||
@@ -252,10 +278,36 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
True,
|
||||
None,
|
||||
)
|
||||
fused_weight = torch.cat([w1_q, w3_q], dim=0)
|
||||
fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
|
||||
w_fused_q_packed = convert_weight_packed(fused_weight)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
w_fused_q_packed,
|
||||
w2_q_packed,
|
||||
wkc_packed,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
True,
|
||||
False,
|
||||
fused_weight_s,
|
||||
w2_s,
|
||||
True,
|
||||
None,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(fused_q_out, q_out))
|
||||
self.assertTrue(torch.allclose(fused_k_out, k_out))
|
||||
self.assertTrue(torch.allclose(fused_v_out, v_out))
|
||||
|
||||
def test_fp8_qkv_proj_with_rope(self):
|
||||
dtype = torch.bfloat16
|
||||
@@ -311,17 +363,17 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
)
|
||||
fp8_q_a_proj_weight = convert_weight_packed(fp8_q_a_proj_weight)
|
||||
fp8_q_b_proj_weight = convert_weight_packed(fp8_q_b_proj_weight)
|
||||
fp8_kv_a_proj_with_mqa_weight = convert_weight_packed(
|
||||
fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
|
||||
fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
|
||||
fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
|
||||
fp8_kv_a_proj_with_mqa_weight
|
||||
)
|
||||
w_kc = convert_weight_packed(w_kc)
|
||||
q_out, k_out, v_out = qkv_proj_with_rope(
|
||||
hidden_states,
|
||||
fp8_q_a_proj_weight,
|
||||
fp8_q_b_proj_weight,
|
||||
fp8_kv_a_proj_with_mqa_weight,
|
||||
fp8_q_a_proj_weight_packed,
|
||||
fp8_q_b_proj_weight_packed,
|
||||
fp8_kv_a_proj_with_mqa_weight_packed,
|
||||
w_kc,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
@@ -336,10 +388,44 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
True,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
)
|
||||
|
||||
fused_weight = torch.cat(
|
||||
[fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
|
||||
)
|
||||
fused_weight_s = torch.cat(
|
||||
[q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
|
||||
)
|
||||
fused_weight_packed = convert_weight_packed(fused_weight)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
fused_weight_packed,
|
||||
fp8_q_b_proj_weight_packed,
|
||||
w_kc,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
False,
|
||||
True,
|
||||
fused_weight_s.float(),
|
||||
q_b_proj_weight_scale_inv.float(),
|
||||
True,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
|
||||
# Due to the change in multiplication order, the error is amplified.
|
||||
# In the model, with fewer layers, this doesn't cause issues, but in
|
||||
# tests with more layers, we need to enlarge the tolerance to pass the tests.
|
||||
torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
|
||||
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(fused_q_out, q_out)
|
||||
torch.testing.assert_close(fused_k_out, k_out)
|
||||
torch.testing.assert_close(fused_v_out, v_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -34,7 +34,15 @@ class TestGroupedTopK(CustomTestCase):
|
||||
|
||||
# fused version
|
||||
topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
|
||||
hidden_states, gating_output, topk, renormalize, G, topk_group
|
||||
hidden_states,
|
||||
gating_output,
|
||||
topk,
|
||||
renormalize,
|
||||
G,
|
||||
topk_group,
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
@@ -83,6 +91,9 @@ class TestBiasedGroupedTopK(CustomTestCase):
|
||||
renormalize,
|
||||
G,
|
||||
topk_group,
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
|
||||
@@ -244,3 +244,11 @@ def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
|
||||
.sum(dim=1)
|
||||
.to(a.dtype)
|
||||
)
|
||||
|
||||
|
||||
def make_non_contiguous(x: torch.Tensor) -> torch.Tensor:
    """
    Make a tensor non-contiguous by slicing it via last dimension.

    Returns ``x[..., : x.shape[-1] // 2]`` — slicing the last dimension keeps
    the original strides, so the result is a non-contiguous view. If ``x`` is
    already non-contiguous it is returned unchanged. NOTE: the returned
    tensor's last dimension is halved, so callers must re-query sizes
    afterwards (the norm test does exactly that).
    """
    last_dim = x.shape[-1]
    return x[..., : last_dim // 2] if x.is_contiguous() else x
|
||||
|
||||
Reference in New Issue
Block a user