feat: update support for qwen3next model (#10466)
This commit is contained in:
@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
b_h *= exp(b_g)
|
||||
@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
b_h *= exp(b_g)
|
||||
|
||||
@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
|
||||
|
||||
# Apply L2 normalization if enabled
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
|
||||
|
||||
b_q = b_q * scale
|
||||
|
||||
|
||||
@@ -239,6 +239,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
||||
self,
|
||||
config: Qwen3NextConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -278,6 +279,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
||||
input_size=self.hidden_size,
|
||||
output_size=projection_size_qkvz,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
tp_rank=self.attn_tp_rank,
|
||||
tp_size=self.attn_tp_size,
|
||||
)
|
||||
@@ -285,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
||||
input_size=self.hidden_size,
|
||||
output_size=projection_size_ba,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
tp_rank=self.attn_tp_rank,
|
||||
tp_size=self.attn_tp_size,
|
||||
)
|
||||
@@ -336,6 +339,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
||||
self.value_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
input_is_parallel=True,
|
||||
reduce_results=False,
|
||||
tp_rank=self.attn_tp_rank,
|
||||
@@ -493,7 +497,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
|
||||
self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, quant_config, alt_stream)
|
||||
|
||||
# Qwen3Next all layers are sparse and have no nextn now
|
||||
self.is_layer_sparse = True
|
||||
|
||||
Reference in New Issue
Block a user