diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
index 699e5af68..19e9bd516 100644
--- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
+++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -13,9 +13,6 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
 from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
     fused_sigmoid_gating_delta_rule_update,
 )
-from sglang.srt.layers.attention.mamba.causal_conv1d import (
-    causal_conv1d_fn as causal_conv1d_fn_sgl,
-)
 from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
     causal_conv1d_fn,
     causal_conv1d_update,
@@ -26,9 +23,15 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_cuda, is_npu
 
-if is_npu():
+if is_cuda():
+    from sglang.srt.layers.attention.mamba.causal_conv1d import (
+        causal_conv1d_fn as causal_conv1d_fn_cuda,
+    )
+
+    causal_conv1d_fn = causal_conv1d_fn_cuda
+elif is_npu():
     from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
     from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
         fused_sigmoid_gating_delta_rule_update_npu,
@@ -350,7 +353,7 @@ class MambaAttnBackend(AttentionBackend):
                 mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1)
             )
         else:
-            mixed_qkv = causal_conv1d_fn_sgl(
+            mixed_qkv = causal_conv1d_fn(
                 mixed_qkv.transpose(0, 1),
                 conv_weights,
                 bias,
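The change replaces the unconditional import of the CUDA `causal_conv1d_fn` kernel (previously aliased as `causal_conv1d_fn_sgl`) with an import-time dispatch: the Triton implementation stays bound to the name `causal_conv1d_fn` by default, and on CUDA that name is shadowed by the CUDA kernel, so the call site in `MambaAttnBackend` uses a single name regardless of backend. Below is a minimal, self-contained sketch of that dispatch pattern under assumed, hypothetical helper names; it is not the sglang module layout itself.

```python
# Sketch of import-time backend dispatch: bind one public name to whichever
# implementation is available, so call sites never branch on the platform.
# _causal_conv1d_triton / _causal_conv1d_cuda are hypothetical placeholders.

def _causal_conv1d_triton(x):
    """Portable fallback implementation (placeholder)."""
    return x

def _causal_conv1d_cuda(x):
    """Backend-specific fused kernel (placeholder)."""
    return x

def _is_cuda() -> bool:
    """Stand-in for a platform check like sglang.srt.utils.is_cuda()."""
    try:
        import torch
        return torch.cuda.is_available()
    except ImportError:
        return False

# Default to the portable kernel, then shadow the name when a faster
# backend-specific kernel exists -- the same move the diff makes with
# causal_conv1d_fn / causal_conv1d_fn_cuda.
causal_conv1d_fn = _causal_conv1d_triton
if _is_cuda():
    causal_conv1d_fn = _causal_conv1d_cuda
```

With the dispatch done once at import time, the `else:` branch in `MambaAttnBackend` can call `causal_conv1d_fn(...)` directly and the backend-specific alias is no longer needed anywhere else in the file.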