diff --git a/sgl-kernel/tests/speculative/test_eagle_utils.py b/sgl-kernel/tests/speculative/test_eagle_utils.py
index 12aa2e498..03e6825de 100644
--- a/sgl-kernel/tests/speculative/test_eagle_utils.py
+++ b/sgl-kernel/tests/speculative/test_eagle_utils.py
@@ -49,7 +49,6 @@ def test_verify_tree_greedy():
             if torch.max(target_logits[i][j]) < 10:
                 target_logits[i][j][18] = 10
-    print(f"{target_logits=}")
     target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
 
     predict_shape = (12,)
@@ -65,12 +64,6 @@ def test_verify_tree_greedy():
     )  # mutable
     accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable
 
-    print(f"{candidates=}")
-    print(f"{retrive_index=}")
-    print(f"{retrive_next_token=}")
-    print(f"{retrive_next_sibling=}")
-    print(f"{target_predict=}")
-
     verify_tree_greedy(
         predicts=predicts,
         accept_index=accept_index,
@@ -82,10 +75,6 @@ def test_verify_tree_greedy():
         target_predict=target_predict,
     )
 
-    print(f"{predicts=}")
-    print(f"{accept_index=}")
-    print(f"{accept_token_num=}")
-
     # Check the expected output.
     assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
     assert accept_index.tolist() == [
diff --git a/sgl-kernel/tests/speculative/test_speculative_sampling.py b/sgl-kernel/tests/speculative/test_speculative_sampling.py
index 93f3f5093..56dd02b84 100644
--- a/sgl-kernel/tests/speculative/test_speculative_sampling.py
+++ b/sgl-kernel/tests/speculative/test_speculative_sampling.py
@@ -3,18 +3,47 @@ import torch
 import torch.nn.functional as F
 from sgl_kernel import tree_speculative_sampling_target_only
 
+test_cases = [
+    (
+        1,  # threshold_single
+        1,  # threshold_acc
+        [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
+        [[0, 3, 4, 5], [6, 10, 11, -1]],
+        [3, 2],
+    ),
+    (
+        0,  # threshold_single
+        0,  # threshold_acc
+        [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
+        [[0, 1, 2, -1], [6, 10, 11, -1]],
+        [2, 2],
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
+    test_cases,
+)
+def test_tree_speculative_sampling_target_only(
+    threshold_single,
+    threshold_acc,
+    expected_predicts,
+    expected_accept_index,
+    expected_accept_token_num,
+):
+    """
+    Test tree_speculative_sampling_target_only across threshold settings using pytest parametrization.
+ """ + device = "cuda" -def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1): - print( - f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n" - ) candidates = torch.tensor( [ [0, 1, 2, 3, 4, 5], [7, 8, 9, 10, 11, 12], ], dtype=torch.int32, - device="cuda", + device=device, ) retrive_index = torch.tensor( [ @@ -22,7 +51,7 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc [6, 7, 8, 9, 10, 11], ], dtype=torch.int32, - device="cuda", + device=device, ) retrive_next_token = torch.tensor( [ @@ -30,7 +59,7 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc [4, 2, 3, -1, 5, -1], ], dtype=torch.int32, - device="cuda", + device=device, ) retrive_next_sibling = torch.tensor( [ @@ -38,45 +67,34 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc [-1, -1, -1, -1, 1, -1], ], dtype=torch.int32, - device="cuda", + device=device, ) - target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda") + target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device) target_logits[0, 0, 3] = 10 target_logits[0, 3, 4] = 10 target_logits[0, 4, 5] = 10 target_logits[1, 0, 11] = 10 target_logits[1, 4, 12] = 10 + for i in range(target_logits.shape[0]): for j in range(target_logits.shape[1]): - if torch.max(target_logits[i][j]) < 10: - target_logits[i][j][18] = 10 + if torch.max(target_logits[i, j]) < 10: + target_logits[i, j, 18] = 10 - temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device="cuda") - predict_shape = (12,) + temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device) + bs, num_draft_tokens = candidates.shape + num_spec_step = len(expected_accept_index[0]) + predict_shape = (len(expected_predicts),) - bs = candidates.shape[0] - num_spec_step = 4 - num_draft_tokens = candidates.shape[1] - - predicts = torch.full( - predict_shape, -1, dtype=torch.int32, device="cuda" - ) # mutable - accept_index = torch.full( - (bs, num_spec_step), -1, dtype=torch.int32, device="cuda" - ) # mutable - accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable + predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device) + accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device=device) + accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device) expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1) target_probs = F.softmax(target_logits / expanded_temperature, dim=-1) - draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device="cuda") - - coins = torch.rand(bs, num_draft_tokens, device="cuda").to(torch.float32) - print(f"{candidates=}") - print(f"{retrive_index=}") - print(f"{retrive_next_token=}") - print(f"{retrive_next_sibling=}") - print(f"{coins=}") + draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device) + coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32) tree_speculative_sampling_target_only( predicts=predicts, @@ -94,24 +112,15 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc deterministic=True, ) - print(f"{predicts=}") - print(f"{accept_index=}") - print(f"{accept_token_num=}") - - if threshold_single == 1 and threshold_acc == 1: - assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] - assert accept_index.tolist() == [ - [0, 3, 4, 5], - [6, 10, 11, -1], - ] - assert 
-    elif threshold_single == 0 and threshold_acc == 0:
-        assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
-        assert accept_index.tolist() == [
-            [0, 1, 2, -1],
-            [6, 10, 11, -1],
-        ]
-        assert accept_token_num.tolist() == [2, 2]
+    assert (
+        predicts.tolist() == expected_predicts
+    ), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"
+    assert (
+        accept_index.tolist() == expected_accept_index
+    ), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"
+    assert (
+        accept_token_num.tolist() == expected_accept_token_num
+    ), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"
 
 
 if __name__ == "__main__":
diff --git a/sgl-kernel/tests/test_fp8_blockwise_gemm.py b/sgl-kernel/tests/test_fp8_blockwise_gemm.py
index 6872476f0..4c1dde336 100644
--- a/sgl-kernel/tests/test_fp8_blockwise_gemm.py
+++ b/sgl-kernel/tests/test_fp8_blockwise_gemm.py
@@ -79,7 +79,6 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
     rtol = 0.02
     atol = 1
     torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-    print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
 
 
 @pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py
index 9f4103e1d..4d506faed 100644
--- a/sgl-kernel/tests/test_int8_gemm.py
+++ b/sgl-kernel/tests/test_int8_gemm.py
@@ -28,7 +28,6 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
     o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     torch.testing.assert_close(o, o1)
-    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
 
 
 @pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py
index 8f5d4bb77..f2f0ba258 100644
--- a/sgl-kernel/tests/test_lightning_attention_decode.py
+++ b/sgl-kernel/tests/test_lightning_attention_decode.py
@@ -70,8 +70,6 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
         ref_output,
         rtol=rtol,
         atol=atol,
-        msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
-        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
     )
 
     torch.testing.assert_close(
@@ -79,8 +77,6 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
         ref_new_kv,
         rtol=rtol,
         atol=atol,
-        msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
-        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
     )
diff --git a/sgl-kernel/tests/test_moe_topk_softmax.py b/sgl-kernel/tests/test_moe_topk_softmax.py
index 09acde584..420a3a6d6 100644
--- a/sgl-kernel/tests/test_moe_topk_softmax.py
+++ b/sgl-kernel/tests/test_moe_topk_softmax.py
@@ -42,12 +42,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
         topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: torch={topk_indices_ref} vs SGLang={topk_weights}"
 
-    assert torch.equal(
-        topk_indices_ref, topk_indices
+    assert torch.allclose(
+        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
     ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"
 
-    print("✅ Native torch and custom kernel implementations match.")
-
 
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/sgl-kernel/tests/test_per_token_group_quant_8bit.py b/sgl-kernel/tests/test_per_token_group_quant_8bit.py
index 083ca1cad..ba3bafda5 100644
--- a/sgl-kernel/tests/test_per_token_group_quant_8bit.py
+++ b/sgl-kernel/tests/test_per_token_group_quant_8bit.py
@@ -304,10 +304,10 @@ def test_per_token_group_quant_with_column_major(
         scale_tma_aligned=scale_tma_aligned,
     )
 
-    assert torch.allclose(
+    torch.testing.assert_close(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
     )
-    assert torch.allclose(
+    torch.testing.assert_close(
         x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
     )
diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py
index fa937a604..539b51d84 100644
--- a/sgl-kernel/tests/test_rotary_embedding.py
+++ b/sgl-kernel/tests/test_rotary_embedding.py
@@ -187,9 +187,6 @@ def test_correctness(
         pos_ids, query_flashinfer, key_flashinfer
     )
 
-    print(query_ref_out)
-    print(query_flashinfer_out)
-
     torch.testing.assert_close(
         query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
     )
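
A minimal standalone sketch of the torch.testing.assert_close behavior the updated tests rely on; the tensors below are illustrative only and not part of the diff.

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0 + 1e-6])

# Passes: the difference is within rtol/atol. On failure, assert_close raises an
# AssertionError that reports how many elements mismatch and the greatest
# absolute/relative difference, whereas a bare `assert torch.allclose(...)`
# only fails with whatever message the caller builds by hand.
torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-5)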