Files
sglang/python/sglang/srt/models/qwen3_next.py
2025-10-21 00:17:02 -07:00

1065 lines
36 KiB
Python

import enum
import logging
from typing import Any, Iterable, Optional, Set, Tuple
import torch
from torch import nn
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.distributed import divide, get_pp_group
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
sharded_weight_loader,
)
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
LazyValue,
add_prefix,
is_cuda,
is_npu,
make_layers,
set_weight_attrs,
)
# Module-level logger for this model file.
logger = logging.getLogger(__name__)
# Platform flags evaluated once at import time. `_is_cuda` decides whether the
# model creates a side CUDA stream; `_is_npu` disables the dual-stream input
# projection path in the linear-attention layer.
_is_cuda = is_cuda()
_is_npu = is_npu()
import triton
import triton.language as tl
@triton.jit
def fused_qkvzba_split_reshape_cat_kernel(
    mixed_qkv,
    z,
    b,
    a,
    mixed_qkvz,
    mixed_ba,
    NUM_HEADS_QK: tl.constexpr,
    NUM_HEADS_V: tl.constexpr,
    HEAD_QK: tl.constexpr,
    HEAD_V: tl.constexpr,
):
    # Scatter one (row, qk-head-group) slice of the packed projections into
    # the contiguous layouts expected downstream.
    #
    # Program ids: i_bs = row (token), i_qk = qk head group. With
    # r = NUM_HEADS_V // NUM_HEADS_QK, each group in a row of `mixed_qkvz` is
    #   [q: HEAD_QK | k: HEAD_QK | v: r*HEAD_V | z: r*HEAD_V]
    # and each group in a row of `mixed_ba` is [b: r | a: r].
    # Outputs per row:
    #   mixed_qkv = [q of all groups | k of all groups | v of all groups]
    #   z         = r*HEAD_V values per group, groups back to back
    #   b, a      = NUM_HEADS_V values each
    # NOTE(review): tl.arange extents below (HEAD_QK, r*HEAD_V) must be powers
    # of two for Triton; the caller only enables this path for r in {1, 2, 4}.
    i_bs, i_qk = tl.program_id(0), tl.program_id(1)
    # Per-group widths of the packed inputs/outputs.
    QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
    BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
    QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    # Source pointers: consecutive segments of this group's packed qkvz slice.
    q_end: tl.constexpr = HEAD_QK
    blk_q_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(0, q_end)
    )
    k_end: tl.constexpr = q_end + HEAD_QK
    blk_k_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(q_end, k_end)
    )
    v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_v_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(k_end, v_end)
    )
    z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_z_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(v_end, z_end)
    )
    # Destination pointers: q goes into the q region of the row, k into the k
    # region (offset by all groups' q), v into the v region (offset by all
    # groups' q and k).
    blk_q_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + i_qk * HEAD_QK
        + tl.arange(0, HEAD_QK)
    )
    blk_k_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + NUM_HEADS_QK * HEAD_QK
        + i_qk * HEAD_QK
        + tl.arange(0, HEAD_QK)
    )
    blk_v_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + NUM_HEADS_QK * HEAD_QK * 2
        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
    )
    blk_z_st_ptr = (
        z
        + i_bs * NUM_HEADS_V * HEAD_V
        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
    )
    tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
    tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
    tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
    tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
    # b occupies the first r entries of this group's mixed_ba slice, a the
    # next r; copied element by element (unrolled at compile time).
    b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
    a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
    for i in tl.static_range(b_end):
        blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
        tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
    for i in tl.static_range(b_end, a_end):
        blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_a_st_ptr = (
            a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
        )
        tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
def fused_qkvzba_split_reshape_cat(
    mixed_qkvz,
    mixed_ba,
    num_heads_qk,
    num_heads_v,
    head_qk,
    head_v,
):
    """Unpack grouped qkvz / ba projections with a single Triton launch.

    Every row of ``mixed_qkvz`` / ``mixed_ba`` is treated as one token (the
    sequence length is fixed at 1). Returns a tuple ``(mixed_qkv, z, b, a)``:

    * ``mixed_qkv``: [rows, 2 * num_heads_qk * head_qk + num_heads_v * head_v]
    * ``z``: [rows, num_heads_v, head_v]
    * ``b`` / ``a``: [rows, num_heads_v], in ``mixed_ba``'s dtype and device
    """
    num_rows = mixed_qkvz.shape[0]  # seq_len is implicitly 1
    qkv_width = 2 * num_heads_qk * head_qk + num_heads_v * head_v
    like_qkvz = {"dtype": mixed_qkvz.dtype, "device": mixed_qkvz.device}

    mixed_qkv = torch.empty([num_rows, qkv_width], **like_qkvz)
    z = torch.empty([num_rows, num_heads_v, head_v], **like_qkvz)
    b = torch.empty(
        [num_rows, num_heads_v], dtype=mixed_ba.dtype, device=mixed_ba.device
    )
    a = torch.empty_like(b)

    # One program per (row, qk head group).
    launch_grid = (num_rows, num_heads_qk)
    fused_qkvzba_split_reshape_cat_kernel[launch_grid](
        mixed_qkv,
        z,
        b,
        a,
        mixed_qkvz,
        mixed_ba,
        num_heads_qk,
        num_heads_v,
        head_qk,
        head_v,
        num_warps=1,
        num_stages=3,
    )
    return mixed_qkv, z, b, a
# Eager reference for the value computed by the kernel below:
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
@triton.jit
def fused_gdn_gating_kernel(
    g,
    A_log,
    a,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    # Compute g = -exp(A_log) * softplus(a + dt_bias) for BLK_HEADS heads.
    # Program ids: (row, sequence position, head block). `a` and `g` are
    # addressed as contiguous [batch, seq_len, NUM_HEADS]; `A_log` and
    # `dt_bias` are indexed per head only. Math is done in float32.
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
    # The last head block may run past NUM_HEADS; mask those lanes out.
    mask = head_off < NUM_HEADS
    blk_A_log = tl.load(A_log + head_off, mask=mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=mask)
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
    # Softplus with the same beta/threshold convention as torch: revert to
    # the identity once beta * x exceeds threshold.
    softplus_x = tl.where(
        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
    )
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """Return float32 ``g = -exp(A_log) * softplus(a + dt_bias)``.

    ``a`` is 2-D ``[rows, num_heads]``. ``A_log`` and ``dt_bias`` hold one
    value per head. ``beta`` and ``threshold`` parameterize the softplus.
    Heads are processed 8 per Triton program.
    """
    rows, num_heads = a.shape
    heads_per_program = 8
    seq_len = 1
    g = torch.empty_like(a, dtype=torch.float32)
    launch_grid = (rows, seq_len, triton.cdiv(num_heads, heads_per_program))
    fused_gdn_gating_kernel[launch_grid](
        g,
        A_log,
        a,
        dt_bias,
        seq_len,
        num_heads,
        beta,
        threshold,
        heads_per_program,
        num_warps=1,
    )
    return g
class Qwen3GatedDeltaNet(nn.Module):
    """Gated DeltaNet linear-attention mixer used by Qwen3-Next linear layers.

    Projects hidden states into packed q/k/v/z and b/a streams. The short
    convolution and the core linear-attention update are delegated to
    ``forward_batch.attn_backend``, which receives the conv weights,
    ``A_log``, ``dt_bias``, ``a`` and ``b``. A gated RMSNorm and the output
    projection are then applied. Projections are sharded over the attention
    tensor-parallel group.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.attn_tp_rank = get_attention_tp_rank()
        self.attn_tp_size = get_attention_tp_size()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        # Optional side stream for overlapping the two input projections.
        # It may be None (the default, and what the model passes off-CUDA).
        self.alt_stream = alt_stream
        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_id = layer_id
        self.activation = config.hidden_act
        self.layer_norm_epsilon = config.rms_norm_eps

        # QKV convolution weights. An extra singleton dim is inserted at
        # position 1; forward() views the weight back to 2-D.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.conv_dim,
            bias=False,
            quant_config=None,
            tp_rank=self.attn_tp_rank,
            tp_size=self.attn_tp_size,
        )
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # Projections of the input hidden states: packed q/k/v/z and b/a.
        projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
        projection_size_ba = self.num_v_heads * 2
        self.in_proj_qkvz = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=projection_size_qkvz,
            bias=False,
            quant_config=quant_config,
            tp_rank=self.attn_tp_rank,
            tp_size=self.attn_tp_size,
        )
        self.in_proj_ba = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=projection_size_ba,
            bias=False,
            quant_config=None,
            tp_rank=self.attn_tp_rank,
            tp_size=self.attn_tp_size,
        )

        # The conv weight's output dim is the concatenation [q | k | v].
        # Replace the default loader with one given per-segment sizes
        # (presumably sharding each segment separately — see
        # mamba_v2_sharded_weight_loader).
        query_key_settings = (self.key_dim, 0, False)
        value_settings = (self.value_dim, 0, False)
        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight,
            {
                "weight_loader": mamba_v2_sharded_weight_loader(
                    [
                        query_key_settings,
                        query_key_settings,
                        value_settings,
                    ],
                    self.attn_tp_size,
                    self.attn_tp_rank,
                )
            },
        )

        # Per-value-head time-step bias and decay (A_log), both sharded along
        # dim 0. `divide` rejects head counts that don't split evenly.
        num_v_heads_per_rank = divide(self.num_v_heads, self.attn_tp_size)
        self.dt_bias = nn.Parameter(torch.ones(num_v_heads_per_rank))
        A = torch.empty(num_v_heads_per_rank, dtype=torch.float32).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.norm = RMSNormGated(
            self.head_v_dim,
            eps=self.layer_norm_epsilon,
            group_size=None,
            norm_before_gate=True,
            device=torch.get_device_module().current_device(),
            dtype=config.torch_dtype,
        )
        # reduce_results=False: the output is not all-reduced here
        # (presumably done by the layer communicator — confirm).
        self.out_proj = RowParallelLinear(
            self.value_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            input_is_parallel=True,
            reduce_results=False,
            tp_rank=self.attn_tp_rank,
            tp_size=self.attn_tp_size,
        )

    def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
        """Split the grouped projections into query, key, value, z, b and a.

        ``mixed_qkvz`` is viewed as [tokens, k_heads_per_rank, group_width],
        where each group is [q | k | v | z]. ``mixed_ba`` is viewed as
        [tokens, k_heads_per_rank, 2 * r], where each group is [b | a] and
        r = num_v_heads // num_k_heads. value and z are reshaped to
        [tokens, v_heads_per_rank, head_v_dim]. b and a are reshaped to
        [tokens, v_heads_per_rank].
        """
        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            self.num_k_heads // self.attn_tp_size,
            (
                self.head_k_dim
                + self.head_k_dim
                + (self.head_v_dim + self.head_v_dim)
                * self.num_v_heads
                // self.num_k_heads
            ),
        )
        new_tensor_shape_ba = mixed_ba.size()[:-1] + (
            self.num_k_heads // self.attn_tp_size,
            2 * self.num_v_heads // self.num_k_heads,
        )
        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
        ]
        split_arg_list_ba = [
            self.num_v_heads // self.num_k_heads,
            self.num_v_heads // self.num_k_heads,
        ]
        # [b, sq, ng, (hn + hn + np/ng * hn + np/ng * hn)]
        # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn]
        (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
        # [b, sq, ng, 2 * np/ng] --> [b, sq, ng, np/ng], [b, sq, ng, np/ng]
        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
        value = value.reshape(value.size(0), -1, self.head_v_dim)
        z = z.reshape(z.size(0), -1, self.head_v_dim)
        b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size)
        a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size)
        return query, key, value, z, b, a

    def _forward_input_proj(self, hidden_states: torch.Tensor):
        """Return ``(projected_qkvz, projected_ba)`` for ``hidden_states``.

        For small batches the ba projection runs on ``alt_stream`` so it can
        overlap with the qkvz projection. Without a side stream, or on NPU
        (threshold 0), both projections run sequentially on the current
        stream.
        """
        DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
        seq_len, _ = hidden_states.shape
        # alt_stream may be None (e.g. when no CUDA stream was created);
        # dereferencing it unconditionally would raise AttributeError.
        if self.alt_stream is not None and seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
            current_stream = torch.cuda.current_stream()
            # The side stream must see hidden_states as produced so far.
            self.alt_stream.wait_stream(current_stream)
            projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
            with torch.cuda.stream(self.alt_stream):
                projected_states_ba, _ = self.in_proj_ba(hidden_states)
            # Join back before either result is consumed.
            current_stream.wait_stream(self.alt_stream)
        else:
            projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
            projected_states_ba, _ = self.in_proj_ba(hidden_states)
        return projected_states_qkvz, projected_states_ba

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Apply the gated-delta-net mixer to 2-D ``hidden_states`` [tokens, hidden]."""
        seq_len, _ = hidden_states.shape
        is_cuda_graph = forward_batch.forward_mode.is_cuda_graph()
        projected_states_qkvz, projected_states_ba = self._forward_input_proj(
            hidden_states
        )
        # Fused split kernel: used under CUDA graphs for v/k head ratios of
        # 1, 2 or 4. This keeps the kernel's tl.arange extents powers of two,
        # assuming power-of-two head dims.
        if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph:
            mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
                projected_states_qkvz,
                projected_states_ba,
                triton.cdiv(self.num_k_heads, self.attn_tp_size),
                triton.cdiv(self.num_v_heads, self.attn_tp_size),
                self.head_k_dim,
                self.head_v_dim,
            )
        else:
            query, key, value, z, b, a = self.fix_query_key_value_ordering(
                projected_states_qkvz, projected_states_ba
            )
            query, key, value = map(
                lambda x: x.reshape(x.shape[0], -1), (query, key, value)
            )
            mixed_qkv = torch.cat((query, key, value), dim=-1)

        # 2. Convolution sequence transformation (performed by the backend).
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )
        kwargs = {
            "mixed_qkv": mixed_qkv,
            "conv_weights": conv_weights,
            "bias": self.conv1d.bias,
            "activation": self.activation,
            "key_dim": self.key_dim,
            "value_dim": self.value_dim,
            "attention_tp_size": self.attn_tp_size,
            "head_k_dim": self.head_k_dim,
            "head_v_dim": self.head_v_dim,
            "a": a,
            "b": b,
            "A_log": self.A_log,
            "dt_bias": self.dt_bias,
            "layer_id": self.layer_id,
            "seq_len": seq_len,
            "num_k_heads": self.num_k_heads,
            "num_v_heads": self.num_v_heads,
            "z": z,
        }
        core_attn_out = forward_batch.attn_backend.forward(
            q=None,
            k=None,
            v=None,
            layer=None,
            forward_batch=forward_batch,
            **kwargs,
        )

        # Gated RMSNorm over head_v_dim on flattened 2-D views, then restore
        # z's shape and merge the head dims for the output projection.
        z_shape_og = z.shape
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
        output, _ = self.out_proj(core_attn_out)
        return output
class Qwen3HybridLinearDecoderLayer(nn.Module):
    """Decoder layer pairing Gated DeltaNet linear attention with an MLP block.

    Normalization, residual bookkeeping and any scatter/gather around the
    attention and MLP sub-blocks are delegated to a ``LayerCommunicator``.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, quant_config, alt_stream)

        # Every Qwen3Next layer is currently sparse (MoE); there are no nextn layers.
        self.is_layer_sparse = True
        self.layer_id = layer_id
        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=True,
        )

        self.mlp = (
            Qwen2MoeSparseMoeBlock(
                layer_id=layer_id,
                config=config,
                quant_config=quant_config,
                alt_stream=alt_stream,
                prefix=add_prefix("mlp", prefix),
            )
            if self.is_layer_sparse
            else Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        )

        hidden, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = GemmaRMSNorm(hidden, eps=eps)
        self.post_attention_layernorm = GemmaRMSNorm(hidden, eps=eps)
        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Run linear attention then the MLP; return ``(hidden_states, residual)``.

        ``forward_batch`` is taken from ``kwargs``. Other keyword arguments
        (e.g. ``positions``, ``layer_id``) are accepted and ignored.
        """
        forward_batch = kwargs.get("forward_batch")
        comm = self.layer_communicator

        hidden_states, residual = comm.prepare_attn(
            hidden_states, residual, forward_batch
        )
        # Idle batches skip the mixer but still go through the communicator.
        if not forward_batch.forward_mode.is_idle():
            hidden_states = self.linear_attn(hidden_states, forward_batch)

        # Fully connected block.
        hidden_states, residual = comm.prepare_mlp(
            hidden_states, residual, forward_batch
        )
        hidden_states = self.mlp(
            hidden_states,
            forward_batch,
            comm.should_use_reduce_scatter(forward_batch),
        )
        hidden_states, residual = comm.postprocess_layer(
            hidden_states, residual, forward_batch
        )
        return hidden_states, residual
class Qwen3HybridAttentionDecoderLayer(nn.Module):
    """Full (softmax) attention decoder layer of Qwen3-Next plus an MLP block.

    Attention uses per-head q/k RMSNorm, rotary embeddings and, when
    ``config.attn_output_gate`` is set (default True), a sigmoid output gate
    produced alongside q by ``qkv_proj``.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.attn_tp_rank = get_attention_tp_rank()
        self.attn_tp_size = get_attention_tp_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % self.attn_tp_size == 0
        self.num_heads = self.total_num_heads // self.attn_tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= self.attn_tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.attn_tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert self.attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
        # When the config does not set head_dim, derive it from the TOTAL head
        # count. Dividing by the per-rank head count would inflate head_dim
        # by attn_tp_size and mis-size q/k/v and o_proj.
        self.head_dim = getattr(config, "head_dim", None) or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = getattr(config, "rope_theta", 10000)
        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.partial_rotary_factor = config.partial_rotary_factor
        self.layer_id = layer_id

        self.attn_output_gate = getattr(config, "attn_output_gate", True)
        if self.attn_output_gate:
            # NOTE(review): `warning_once` is not a stdlib logging.Logger
            # method; presumably patched in by a dependency — confirm.
            logger.warning_once("using attn output gate!")

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_scaling=self.rope_scaling,
            base=self.rope_theta,
            partial_rotary_factor=self.partial_rotary_factor,
            is_neox_style=True,
            dtype=torch.get_default_dtype(),  # see impl of get_rope
        )
        # With the output gate, each query head is paired with a gate head,
        # doubling the q part of the fused projection.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads * (1 + self.attn_output_gate),
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            tp_rank=self.attn_tp_rank,
            tp_size=self.attn_tp_size,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=False,
            tp_rank=self.attn_tp_rank,
            tp_size=self.attn_tp_size,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=f"{prefix}.attn",
        )

        # Qwen3Next all layers are sparse and have no nextn now
        self.is_layer_sparse = True
        is_previous_layer_sparse = True
        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )
        if self.is_layer_sparse:
            self.mlp = Qwen2MoeSparseMoeBlock(
                layer_id=layer_id,
                config=config,
                quant_config=quant_config,
                alt_stream=alt_stream,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
        )
        self.alt_stream = alt_stream

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """RMS-normalize q and k per head; returns tensors in the input shapes."""
        # Overlap the q and k norms on the side stream during graph capture.
        if self.alt_stream is not None and get_is_capture_mode():
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            q_by_head = q.reshape(-1, self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            with torch.cuda.stream(self.alt_stream):
                k_by_head = k.reshape(-1, self.head_dim)
                k_by_head = self.k_norm(k_by_head)
            current_stream.wait_stream(self.alt_stream)
        else:
            q_by_head = q.reshape(-1, self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            k_by_head = k.reshape(-1, self.head_dim)
            k_by_head = self.k_norm(k_by_head)
        q = q_by_head.view(q.shape)
        k = k_by_head.view(k.shape)
        return q, k

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Gated self-attention: qkv projection, qk-norm, RoPE, attention, o_proj."""
        qkv, _ = self.qkv_proj(hidden_states)
        if self.attn_output_gate:
            # The q segment holds [q | gate] interleaved per head; split it
            # per head, then flatten both back to [tokens, heads * head_dim].
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs: Any,
    ):
        """Run attention then the MLP; return ``(hidden_states, residual)``."""
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )
        if not forward_batch.forward_mode.is_idle():
            hidden_states = self.self_attention(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )
        # Fully Connected
        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )
        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )
        return hidden_states, residual
# Decoder-layer class for each value found in ``config.layers_block_type``.
ALL_DECODER_LAYER_TYPES = dict(
    attention=Qwen3HybridAttentionDecoderLayer,
    linear_attention=Qwen3HybridLinearDecoderLayer,
)
class Qwen3NextModel(nn.Module):
    """Qwen3-Next decoder stack: embeddings, hybrid layers and final norm.

    Each layer's class is chosen from ``config.layers_block_type`` via
    ``ALL_DECODER_LAYER_TYPES``. All layers share one side CUDA stream (None
    when not running on CUDA).
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        shared_stream = torch.cuda.Stream() if _is_cuda else None
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            enable_tp=not is_dp_attention_enabled(),
        )

        # `idx` / `prefix` names are kept in case make_layers passes them by keyword.
        def build_layer(idx: int, prefix: str):
            layer_cls = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
            return layer_cls(
                config,
                idx,
                quant_config=quant_config,
                prefix=prefix,
                alt_stream=shared_stream,
            )

        self.layers = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers"
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.infer_count = 0

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed (unless ``inputs_embeds`` is given), run every layer, then normalize.

        Idle batches are returned without the final norm.
        """
        hidden_states = (
            self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
        )
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            # Attribute expert-routing statistics to the current layer.
            with get_global_expert_distribution_recorder().with_current_layer(
                layer_idx
            ):
                hidden_states, residual = layer(
                    layer_id=layer_idx,
                    positions=positions,
                    hidden_states=hidden_states,
                    residual=residual,
                    forward_batch=forward_batch,
                )

        if forward_batch.forward_mode.is_idle():
            return hidden_states
        if residual is None:
            return self.norm(hidden_states)
        # The norm fuses the residual add and returns (normed, new_residual).
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class HybridLayerType(enum.Enum):
    """Names of the hybrid layer kinds.

    NOTE(review): not referenced elsewhere in this file. Two of the values
    ("attention", "linear_attention") match the keys of
    ``ALL_DECODER_LAYER_TYPES``; confirm where the other members are used.
    """

    full_attention = "attention"
    swa_attention = "swa_attention"
    linear_attention = "linear_attention"
    mamba2 = "mamba"
class Qwen3NextForCausalLM(nn.Module):
    """Qwen3-Next causal LM: ``Qwen3NextModel`` followed by a float32 LM head.

    Pipeline parallelism is not supported: the constructor asserts that this
    rank is both the first and the last pipeline stage.
    """

    # NOTE(review): presumably tells the model loader not to fall back to
    # PyTorch-format checkpoint files — confirm with the loader.
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: Qwen3NextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.pp_group = get_pp_group()
        # Single-stage pipelines only.
        assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
        self.quant_config = quant_config
        self.model = Qwen3NextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            org_num_embeddings=config.vocab_size,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
        )
        # The LM head parameters are cast to float32.
        self.lm_head = self.lm_head.float()
        self.logits_processor = LogitsProcessor(config)
        # {layer index: MoE weights} for every layer whose mlp is a
        # Qwen2MoeSparseMoeBlock. Wrapped in LazyValue so it is presumably
        # built on first `.value` access.
        self._routed_experts_weights_of_layer = LazyValue(
            lambda: {
                layer_id: layer.mlp.get_moe_weights()
                for layer_id, layer in enumerate(self.model.layers)
                if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
            }
        )

    @property
    def routed_experts_weights_of_layer(self):
        """Per-layer MoE weights (see ``_routed_experts_weights_of_layer``)."""
        return self._routed_experts_weights_of_layer.value

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the decoder stack and return the logits processor's output."""
        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def get_embed_and_head(self):
        """Return ``(embedding weight, lm_head weight)``."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given tensors.

        Afterwards it empties the CUDA cache and synchronizes the device.
        """
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
    ) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.
            is_mtp: if True, load only ``mtp`` weights. Three top-level MTP
                tensors lose their ``mtp.`` prefix; the rest have ``mtp``
                renamed to ``model``. If False, ``mtp`` weights are skipped.

        Returns:
            The set of (renamed) parameter names that were processed.

        Each tensor is routed, in order, to: a stacked projection shard
        (q/k/v -> qkv_proj, gate/up -> gate_up_proj), an expert shard of a
        fused MoE layer, or a plain parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:

            if is_mtp:

                if "mtp" not in name:
                    continue

                if name in [
                    "mtp.fc.weight",
                    "mtp.pre_fc_norm_embedding.weight",
                    "mtp.pre_fc_norm_hidden.weight",
                ]:
                    name = name.replace("mtp.", "")
                else:
                    name = name.replace("mtp", "model")

            if not is_mtp and "mtp" in name:
                continue

            if "rotary_emb.inv_freq" in name:
                continue

            # Attention submodules are registered directly on the layer here
            # (no `self_attn` attribute), so drop that path component.
            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                # Expert gate/up weights are handled by the expert mapping below.
                # TODO(fix mtp loading)
                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                # if is_pp_missing_parameter(name, self):
                #     continue
                # NOTE(review): this `continue` moves to the next mapping with
                # `name` already rewritten; if none matches, the else-branch
                # below runs with the rewritten name.
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader")
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed the tensor: try expert weights.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    # if is_pp_missing_parameter(name, self):
                    #     continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader")
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Plain (non-stacked, non-expert) parameter.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # if is_pp_missing_parameter(name, self):
                    #     continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Describe the MoE layout (one logical expert set per layer, no groups)."""
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.num_experts,
            num_groups=None,
        )
# Model class exported by this module (presumably discovered by the model
# registry through this name — confirm).
EntryClass = Qwen3NextForCausalLM