diff --git a/sgl-kernel/tests/test_sparse_flash_attn.py b/sgl-kernel/tests/test_sparse_flash_attn.py index 4ddb6d7f5..4aa0a7c19 100644 --- a/sgl-kernel/tests/test_sparse_flash_attn.py +++ b/sgl-kernel/tests/test_sparse_flash_attn.py @@ -10,6 +10,7 @@ from sgl_kernel.sparse_flash_attn import ( sparse_attn_func, sparse_attn_varlen_func, ) +from test_flash_attention import construct_local_mask def ref_attn(