CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
|
||||
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
|
||||
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
|
||||
torch.manual_seed(0)
|
||||
# constants
|
||||
kv_lora_rank = 512
|
||||
@@ -148,6 +149,7 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
kv_a_proj_weight = (
|
||||
torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
|
||||
)
|
||||
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
|
||||
norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
|
||||
pos = torch.randint(10, 100, (B,))
|
||||
cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
|
||||
@@ -167,6 +169,7 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
qb_packed = convert_weight_packed(q_b_proj_weight)
|
||||
kva_packed = convert_weight_packed(kv_a_proj_weight)
|
||||
wkc_packed = convert_weight_packed(w_kc)
|
||||
fused_weight_packed = convert_weight_packed(fused_weight)
|
||||
|
||||
q_out, k_out, v_out = qkv_proj_with_rope(
|
||||
hidden_states,
|
||||
@@ -187,10 +190,33 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
True,
|
||||
None,
|
||||
)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
fused_weight_packed,
|
||||
qb_packed,
|
||||
wkc_packed,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
None,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(fused_q_out, q_out))
|
||||
self.assertTrue(torch.allclose(fused_k_out, k_out))
|
||||
self.assertTrue(torch.allclose(fused_v_out, v_out))
|
||||
|
||||
def test_int8_qkv_proj_with_rope(self):
|
||||
dtype = torch.bfloat16
|
||||
@@ -252,10 +278,36 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
True,
|
||||
None,
|
||||
)
|
||||
fused_weight = torch.cat([w1_q, w3_q], dim=0)
|
||||
fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
|
||||
w_fused_q_packed = convert_weight_packed(fused_weight)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
w_fused_q_packed,
|
||||
w2_q_packed,
|
||||
wkc_packed,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
True,
|
||||
False,
|
||||
fused_weight_s,
|
||||
w2_s,
|
||||
True,
|
||||
None,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(fused_q_out, q_out))
|
||||
self.assertTrue(torch.allclose(fused_k_out, k_out))
|
||||
self.assertTrue(torch.allclose(fused_v_out, v_out))
|
||||
|
||||
def test_fp8_qkv_proj_with_rope(self):
|
||||
dtype = torch.bfloat16
|
||||
@@ -311,17 +363,17 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
)
|
||||
fp8_q_a_proj_weight = convert_weight_packed(fp8_q_a_proj_weight)
|
||||
fp8_q_b_proj_weight = convert_weight_packed(fp8_q_b_proj_weight)
|
||||
fp8_kv_a_proj_with_mqa_weight = convert_weight_packed(
|
||||
fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
|
||||
fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
|
||||
fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
|
||||
fp8_kv_a_proj_with_mqa_weight
|
||||
)
|
||||
w_kc = convert_weight_packed(w_kc)
|
||||
q_out, k_out, v_out = qkv_proj_with_rope(
|
||||
hidden_states,
|
||||
fp8_q_a_proj_weight,
|
||||
fp8_q_b_proj_weight,
|
||||
fp8_kv_a_proj_with_mqa_weight,
|
||||
fp8_q_a_proj_weight_packed,
|
||||
fp8_q_b_proj_weight_packed,
|
||||
fp8_kv_a_proj_with_mqa_weight_packed,
|
||||
w_kc,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
@@ -336,10 +388,44 @@ class TestQKVProjWithROPE(CustomTestCase):
|
||||
True,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
)
|
||||
|
||||
fused_weight = torch.cat(
|
||||
[fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
|
||||
)
|
||||
fused_weight_s = torch.cat(
|
||||
[q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
|
||||
)
|
||||
fused_weight_packed = convert_weight_packed(fused_weight)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
fused_weight_packed,
|
||||
fp8_q_b_proj_weight_packed,
|
||||
w_kc,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
False,
|
||||
True,
|
||||
fused_weight_s.float(),
|
||||
q_b_proj_weight_scale_inv.float(),
|
||||
True,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
|
||||
# Due to the change in multiplication order, the error is amplified.
|
||||
# In the model, with fewer layers, this doesn't cause issues, but in
|
||||
# tests with more layers, we need to enlarge the tolerance to pass the tests.
|
||||
torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
|
||||
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(fused_q_out, q_out)
|
||||
torch.testing.assert_close(fused_k_out, k_out)
|
||||
torch.testing.assert_close(fused_v_out, v_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user