diff --git a/python/sglang/srt/layers/attention/fla/fused_recurrent.py b/python/sglang/srt/layers/attention/fla/fused_recurrent.py
index fa7262ce2..5e9a0c21e 100644
--- a/python/sglang/srt/layers/attention/fla/fused_recurrent.py
+++ b/python/sglang/srt/layers/attention/fla/fused_recurrent.py
@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
diff --git a/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py b/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py
index 41837b980..feeb7c31c 100644
--- a/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py
+++ b/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py
@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
 
         # Apply L2 normalization if enabled
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
 
         b_q = b_q * scale
 
diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py
index 006ce4f91..245145542 100644
--- a/python/sglang/srt/models/qwen3_next.py
+++ b/python/sglang/srt/models/qwen3_next.py
@@ -239,6 +239,7 @@ class Qwen3GatedDeltaNet(nn.Module):
         self,
         config: Qwen3NextConfig,
         layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
@@ -278,6 +279,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_qkvz,
             bias=False,
+            quant_config=quant_config,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -285,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_ba,
             bias=False,
+            quant_config=None,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -336,6 +339,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             self.value_dim,
             self.hidden_size,
             bias=False,
+            quant_config=quant_config,
             input_is_parallel=True,
             reduce_results=False,
             tp_rank=self.attn_tp_rank,
@@ -493,7 +497,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
+        self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, quant_config, alt_stream)
 
         # Qwen3Next all layers are sparse and have no nextn now
         self.is_layer_sparse = True