From d88ef4a388324a62129866bb68a1aeeb263fe1c2 Mon Sep 17 00:00:00 2001 From: Jinyan Chen <93358689+liz-badada@users.noreply.github.com> Date: Sat, 20 Sep 2025 07:59:37 +0800 Subject: [PATCH] limit sgl-kernel causal conv1d to cuda only (#10648) Co-authored-by: Jinyan Chen --- .../attention/hybrid_linear_attn_backend.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 699e5af68..19e9bd516 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -13,9 +13,6 @@ from sglang.srt.layers.attention.fla.fused_recurrent import ( from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import ( fused_sigmoid_gating_delta_rule_update, ) -from sglang.srt.layers.attention.mamba.causal_conv1d import ( - causal_conv1d_fn as causal_conv1d_fn_sgl, -) from sglang.srt.layers.attention.mamba.causal_conv1d_triton import ( causal_conv1d_fn, causal_conv1d_update, @@ -26,9 +23,15 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput -from sglang.srt.utils import is_npu +from sglang.srt.utils import is_cuda, is_npu -if is_npu(): +if is_cuda(): + from sglang.srt.layers.attention.mamba.causal_conv1d import ( + causal_conv1d_fn as causal_conv1d_fn_cuda, + ) + + causal_conv1d_fn = causal_conv1d_fn_cuda +elif is_npu(): from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import ( fused_sigmoid_gating_delta_rule_update_npu, @@ -350,7 +353,7 @@ class MambaAttnBackend(AttentionBackend): mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1) ) else: - mixed_qkv = causal_conv1d_fn_sgl( + mixed_qkv = causal_conv1d_fn( mixed_qkv.transpose(0, 1), conv_weights, bias,