Fix FA4 import causing moe_fused_gate output to hit illegal memory (#10368)
This commit is contained in:
@@ -9,11 +9,6 @@ try:
|
|||||||
except:
|
except:
|
||||||
raise ImportError("Can not import sgl_kernel. Please check your installation.")
|
raise ImportError("Can not import sgl_kernel. Please check your installation.")
|
||||||
|
|
||||||
try:
|
|
||||||
from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
|
|
||||||
except ImportError:
|
|
||||||
flash_attn_varlen_func_v4 = None
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def is_fa3_supported(device=None) -> bool:
|
def is_fa3_supported(device=None) -> bool:
|
||||||
@@ -249,9 +244,8 @@ def flash_attn_varlen_func(
|
|||||||
ver=3,
|
ver=3,
|
||||||
):
|
):
|
||||||
if ver == 4:
|
if ver == 4:
|
||||||
assert (
|
from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
|
||||||
flash_attn_varlen_func_v4 is not None
|
|
||||||
), "FA4 is not available, please check your installation."
|
|
||||||
# Using `(-1, -1)` as no sliding window causes correctness issues for FA4.
|
# Using `(-1, -1)` as no sliding window causes correctness issues for FA4.
|
||||||
if window_size == (-1, -1):
|
if window_size == (-1, -1):
|
||||||
window_size = (None, None)
|
window_size = (None, None)
|
||||||
|
|||||||
Reference in New Issue
Block a user