feat: update support for qwen3next model (#10466)

This commit is contained in:
cao1zhg
2025-09-16 16:09:56 +08:00
committed by GitHub
parent b2435be682
commit b6dd4bcb81
3 changed files with 11 additions and 7 deletions

View File

@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)

View File

@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
         # Apply L2 normalization if enabled
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale

View File

@@ -239,6 +239,7 @@ class Qwen3GatedDeltaNet(nn.Module):
         self,
         config: Qwen3NextConfig,
         layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
@@ -278,6 +279,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_qkvz,
             bias=False,
+            quant_config=quant_config,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -285,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_ba,
             bias=False,
+            quant_config=None,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -336,6 +339,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             self.value_dim,
             self.hidden_size,
             bias=False,
+            quant_config=quant_config,
             input_is_parallel=True,
             reduce_results=False,
             tp_rank=self.attn_tp_rank,
@@ -493,7 +497,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
+        self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, quant_config, alt_stream)
        # Qwen3Next all layers are sparse and have no nextn now
        self.is_layer_sparse = True