feat: update support for qwen3next model (#10466)
This commit is contained in:
@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
|||||||
b_g = tl.load(p_g).to(tl.float32)
|
b_g = tl.load(p_g).to(tl.float32)
|
||||||
|
|
||||||
if USE_QK_L2NORM_IN_KERNEL:
|
if USE_QK_L2NORM_IN_KERNEL:
|
||||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
|
||||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
|
||||||
b_q = b_q * scale
|
b_q = b_q * scale
|
||||||
# [BK, BV]
|
# [BK, BV]
|
||||||
b_h *= exp(b_g)
|
b_h *= exp(b_g)
|
||||||
@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
|
|||||||
b_g = tl.load(p_g).to(tl.float32)
|
b_g = tl.load(p_g).to(tl.float32)
|
||||||
|
|
||||||
if USE_QK_L2NORM_IN_KERNEL:
|
if USE_QK_L2NORM_IN_KERNEL:
|
||||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
|
||||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
|
||||||
b_q = b_q * scale
|
b_q = b_q * scale
|
||||||
# [BK, BV]
|
# [BK, BV]
|
||||||
b_h *= exp(b_g)
|
b_h *= exp(b_g)
|
||||||
|
|||||||
@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
|
|||||||
|
|
||||||
# Apply L2 normalization if enabled
|
# Apply L2 normalization if enabled
|
||||||
if USE_QK_L2NORM_IN_KERNEL:
|
if USE_QK_L2NORM_IN_KERNEL:
|
||||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
|
||||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
|
||||||
|
|
||||||
b_q = b_q * scale
|
b_q = b_q * scale
|
||||||
|
|
||||||
|
|||||||
@@ -239,6 +239,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: Qwen3NextConfig,
|
config: Qwen3NextConfig,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -278,6 +279,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
|||||||
input_size=self.hidden_size,
|
input_size=self.hidden_size,
|
||||||
output_size=projection_size_qkvz,
|
output_size=projection_size_qkvz,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
tp_rank=self.attn_tp_rank,
|
tp_rank=self.attn_tp_rank,
|
||||||
tp_size=self.attn_tp_size,
|
tp_size=self.attn_tp_size,
|
||||||
)
|
)
|
||||||
@@ -285,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
|||||||
input_size=self.hidden_size,
|
input_size=self.hidden_size,
|
||||||
output_size=projection_size_ba,
|
output_size=projection_size_ba,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=None,
|
||||||
tp_rank=self.attn_tp_rank,
|
tp_rank=self.attn_tp_rank,
|
||||||
tp_size=self.attn_tp_size,
|
tp_size=self.attn_tp_size,
|
||||||
)
|
)
|
||||||
@@ -336,6 +339,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
|||||||
self.value_dim,
|
self.value_dim,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
tp_rank=self.attn_tp_rank,
|
tp_rank=self.attn_tp_rank,
|
||||||
@@ -493,7 +497,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
|
self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, quant_config, alt_stream)
|
||||||
|
|
||||||
# Qwen3Next all layers are sparse and have no nextn now
|
# Qwen3Next all layers are sparse and have no nextn now
|
||||||
self.is_layer_sparse = True
|
self.is_layer_sparse = True
|
||||||
|
|||||||
Reference in New Issue
Block a user