move fla env check position (#11500)
This commit is contained in:
@@ -183,6 +183,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
|
|||||||
), "hybrid_gdn can only be used with non-MLA models."
|
), "hybrid_gdn can only be used with non-MLA models."
|
||||||
|
|
||||||
if cfg := runner.mambaish_config:
|
if cfg := runner.mambaish_config:
|
||||||
|
from sglang.srt.layers.attention.fla.utils import check_environments
|
||||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||||
GDNAttnBackend,
|
GDNAttnBackend,
|
||||||
HybridLinearAttnBackend,
|
HybridLinearAttnBackend,
|
||||||
@@ -190,6 +191,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils import is_blackwell, is_npu
|
from sglang.srt.utils import is_blackwell, is_npu
|
||||||
|
|
||||||
|
check_environments()
|
||||||
if runner.hybrid_gdn_config is not None:
|
if runner.hybrid_gdn_config is not None:
|
||||||
if is_blackwell():
|
if is_blackwell():
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
@@ -58,9 +58,6 @@ def check_environments():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
check_environments()
|
|
||||||
|
|
||||||
|
|
||||||
def get_abs_err(x, y):
|
def get_abs_err(x, y):
|
||||||
return (x.detach() - y.detach()).flatten().abs().max().item()
|
return (x.detach() - y.detach()).flatten().abs().max().item()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user