diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py index 33e959703..270d00d44 100644 --- a/sgl-kernel/python/sgl_kernel/flash_attn.py +++ b/sgl-kernel/python/sgl_kernel/flash_attn.py @@ -9,11 +9,6 @@ try: except: raise ImportError("Can not import sgl_kernel. Please check your installation.") -try: - from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4 -except ImportError: - flash_attn_varlen_func_v4 = None - @lru_cache(maxsize=1) def is_fa3_supported(device=None) -> bool: @@ -249,9 +244,8 @@ def flash_attn_varlen_func( ver=3, ): if ver == 4: - assert ( - flash_attn_varlen_func_v4 is not None - ), "FA4 is not available, please check your installation." + from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4 + # Using `(-1, -1)` as no sliding window causes correctness issues for FA4. if window_size == (-1, -1): window_size = (None, None)