diff --git a/sgl-kernel/tests/test_merge_state.py b/sgl-kernel/tests/test_merge_state.py index 2931fa949..70b9628d9 100644 --- a/sgl-kernel/tests/test_merge_state.py +++ b/sgl-kernel/tests/test_merge_state.py @@ -136,3 +136,7 @@ def test_merge_state(seq_len, num_heads, head_dim): assert torch.allclose(v_merged, v_merged_std, atol=1e-2) assert torch.allclose(s_merged, s_merged_std, atol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_merge_state_v2.py b/sgl-kernel/tests/test_merge_state_v2.py index f5c7a30dd..62326f75c 100644 --- a/sgl-kernel/tests/test_merge_state_v2.py +++ b/sgl-kernel/tests/test_merge_state_v2.py @@ -394,3 +394,7 @@ def test_merge_attn_states( len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES) ): generate_markdown_table() + + +if __name__ == "__main__": + pytest.main([__file__])