From 0da0989ad4f468ce35f4c8220241901a75ed1b26 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 23 Jan 2025 21:13:55 +0800 Subject: [PATCH] sync flashinfer and update sgl-kernel tests (#3081) --- .github/workflows/pr-test-sgl-kernel.yml | 2 +- sgl-kernel/3rdparty/flashinfer | 2 +- sgl-kernel/Makefile | 2 +- sgl-kernel/tests/test_activation.py | 3 ++- sgl-kernel/tests/test_lightning_attention_decode.py | 4 ++++ sgl-kernel/tests/test_norm.py | 4 ++++ 6 files changed, 13 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 55eb636d6..aea609697 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -47,7 +47,7 @@ jobs: pip3 list | grep sgl-kernel - name: Run test - timeout-minutes: 10 + timeout-minutes: 30 run: | cd sgl-kernel find tests -name "test_*.py" | xargs -n 1 python3 diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer index 4e8eb1879..93e1a2634 160000 --- a/sgl-kernel/3rdparty/flashinfer +++ b/sgl-kernel/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 4e8eb1879f9c3ba6d75511e5893183bf8f289a62 +Subproject commit 93e1a2634e22355b0856246b032b285ad1d1da6b diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index 9261b8969..c7641bb5f 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -19,7 +19,7 @@ clean: @rm -rf build dist *.egg-info test: - @find tests -name "test_*.py" | xargs -n 1 python3 && pytest tests/test_norm.py && pytest tests/test_activation.py + @find tests -name "test_*.py" | xargs -n 1 python3 format: @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black diff --git a/sgl-kernel/tests/test_activation.py b/sgl-kernel/tests/test_activation.py index f71f36b51..43593441e 100644 --- a/sgl-kernel/tests/test_activation.py +++ b/sgl-kernel/tests/test_activation.py @@ -35,4 +35,5 @@ def test_fused_gelu_mul(dim, batch_size, seq_len): torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) -test_fused_silu_mul(128, 1, 1) +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py index 74af78e27..f2cace001 100644 --- a/sgl-kernel/tests/test_lightning_attention_decode.py +++ b/sgl-kernel/tests/test_lightning_attention_decode.py @@ -82,3 +82,7 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, " f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py index 32f8c25d9..7b38dba72 100644 --- a/sgl-kernel/tests/test_norm.py +++ b/sgl-kernel/tests/test_norm.py @@ -127,3 +127,7 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype): torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__])