Files
enginex-vllm-bi100-qwen36/qwen3_6_scripts/qwen3_5.py

1271 lines
56 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Inference-only Qwen3.6 (Qwen3_5 architecture; dense 27B, plus the MoE block used by Qwen3.6-35B-A3B) for Iluvatar BI-V100.
# Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency).
# Text-only (no VL, no MTP).
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# Pure-PyTorch DeltaNet kernels (fallbacks from transformers 5.2.0)
# ---------------------------------------------------------------------------
def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """Scale ``x`` to (approximately) unit L2 norm along ``dim``.

    ``eps`` is added to the sum of squares before taking the reciprocal
    square root, so an all-zero slice maps to zeros instead of NaN.
    """
    sum_of_squares = torch.sum(x * x, dim=dim, keepdim=True)
    inv_norm = torch.rsqrt(sum_of_squares + eps)
    return x * inv_norm
def _torch_causal_conv1d_update(
    hidden_states: torch.Tensor,  # (batch, channels, seq)
    conv_state: torch.Tensor,     # (batch, channels, state_len), modified in-place
    weight: torch.Tensor,         # (channels, kernel_size)
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Run one incremental step of a depthwise causal 1-D convolution.

    The rolling ``conv_state`` is prepended to ``hidden_states``, the
    convolution is evaluated over that window, and ``conv_state`` is then
    overwritten in place with the last ``state_len`` columns of the window.

    Args:
        hidden_states: New inputs, shape ``(batch, channels, seq)``.
        conv_state: Rolling history, shape ``(batch, channels, state_len)``;
            updated in place.
        weight: Depthwise kernel, shape ``(channels, kernel_size)``.
        bias: Optional per-channel bias forwarded to ``F.conv1d``.
        activation: ``None`` for no activation, or ``"silu"`` / ``"swish"``
            to apply SiLU to the convolution output.

    Returns:
        The last ``seq`` output positions, cast back to
        ``hidden_states.dtype``.

    Raises:
        ValueError: If ``activation`` names anything other than SiLU.
            Previously any non-None string silently selected SiLU.
    """
    if activation not in (None, "silu", "swish"):
        raise ValueError(
            f"Unsupported activation {activation!r}; expected None, 'silu' or 'swish'")

    _, channels, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]

    window = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    # Slice with an explicit start index: ``-state_len:`` would select the
    # whole window (instead of nothing) when state_len == 0.
    conv_state.copy_(window[:, :, window.shape[-1] - state_len:])

    out = F.conv1d(window, weight.unsqueeze(1), bias, padding=0, groups=channels)
    out = out[:, :, -seq_len:]
    if activation is not None:
        out = F.silu(out)
    return out.to(hidden_states.dtype)
def _torch_chunk_gated_delta_rule(
    query: torch.Tensor,  # (batch, seq, num_heads, head_k_dim)
    key: torch.Tensor,
    value: torch.Tensor,  # (batch, seq, num_heads, head_v_dim)
    g: torch.Tensor,      # (batch, seq, num_heads)
    beta: torch.Tensor,   # (batch, seq, num_heads)
    chunk_size: int = 64,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Chunk-parallel gated delta rule in plain PyTorch.

    Inputs are moved to ``(batch, num_heads, seq, dim)`` float32, the
    sequence is zero-padded to a multiple of ``chunk_size``, the
    within-chunk terms are built for all chunks at once, and the recurrent
    state is then carried chunk by chunk.

    Args:
        query, key: ``(batch, seq, num_heads, head_k_dim)``.
        value: ``(batch, seq, num_heads, head_v_dim)``.
        g: Log-space decay per token and head, ``(batch, seq, num_heads)``.
        beta: Update strength per token and head, same shape as ``g``.
        chunk_size: Number of tokens processed per chunk.
        initial_state: Optional starting state, cast to the float32 working
            tensors; zeros of shape ``(batch, num_heads, head_k_dim,
            head_v_dim)`` are used when omitted.
        output_final_state: Return the state after the last chunk when true.
        use_qk_l2norm_in_kernel: L2-normalise query and key first.

    Returns:
        ``(core_out, final_state)``: ``core_out`` has shape
        ``(batch, seq, num_heads, head_v_dim)`` in the input query dtype;
        ``final_state`` is ``None`` unless ``output_final_state`` is set.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query, key = _l2norm(query), _l2norm(key)

    def _heads_first_fp32(t: torch.Tensor) -> torch.Tensor:
        # (batch, seq, heads, ...) -> (batch, heads, seq, ...) in float32
        return t.transpose(1, 2).contiguous().to(torch.float32)

    query = _heads_first_fp32(query)
    key = _heads_first_fp32(key)
    value = _heads_first_fp32(value)
    beta = _heads_first_fp32(beta)
    g = _heads_first_fp32(g)

    batch, num_heads, seq_len, k_dim = key.shape
    v_dim = value.shape[-1]

    # Zero-pad the sequence axis up to a whole number of chunks.
    pad = (chunk_size - seq_len % chunk_size) % chunk_size
    seq_axis_pad = (0, 0, 0, pad)
    query = F.pad(query, seq_axis_pad)
    key = F.pad(key, seq_axis_pad)
    value = F.pad(value, seq_axis_pad)
    beta = F.pad(beta, (0, pad))
    g = F.pad(g, (0, pad))
    num_chunks = (seq_len + pad) // chunk_size

    query = query * (1.0 / (query.shape[-1] ** 0.5))
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    def _into_chunks(t: torch.Tensor) -> torch.Tensor:
        # (batch, heads, seq, dim) -> (batch, heads, chunks, chunk_size, dim)
        return t.reshape(t.shape[0], t.shape[1], -1, chunk_size, t.shape[-1])

    query = _into_chunks(query)
    key = _into_chunks(key)
    k_beta = _into_chunks(k_beta)
    v_beta = _into_chunks(v_beta)
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

    all_true = torch.ones(chunk_size, chunk_size, dtype=torch.bool,
                          device=query.device)
    on_or_above_diag = torch.triu(all_true, diagonal=0)
    above_diag = torch.triu(all_true, diagonal=1)

    # Cumulative log-decay inside each chunk and the pairwise decay factors.
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()

    # Strictly-lower-triangular interaction matrix, then row-by-row forward
    # substitution (each row reads the rows already finalised above it).
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(on_or_above_diag, 0)
    for r in range(1, chunk_size):
        current_row = attn[..., r, :r].clone()
        rows_above = attn[..., :r, :r].clone()
        attn[..., r, :r] = current_row + (current_row.unsqueeze(-1) * rows_above).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)

    chunk_values = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    if initial_state is None:
        state = torch.zeros(batch, num_heads, k_dim, v_dim,
                            dtype=chunk_values.dtype, device=chunk_values.device)
    else:
        state = initial_state.to(chunk_values)

    core_out = torch.zeros_like(chunk_values)
    for c in range(num_chunks):
        q_c, k_c, v_c = query[:, :, c], key[:, :, c], chunk_values[:, :, c]
        g_c = g[:, :, c]
        intra = (q_c @ k_c.transpose(-1, -2) * decay_mask[:, :, c]).masked_fill_(above_diag, 0)
        v_from_state = k_cumdecay[:, :, c] @ state
        v_new = v_c - v_from_state
        inter = (q_c * g_c[:, :, :, None].exp()) @ state
        core_out[:, :, c] = inter + intra @ v_new
        state = (
            state * g_c[:, :, -1, None, None].exp()
            + (k_c * (g_c[:, :, -1, None] - g_c).exp()[..., None])
            .transpose(-1, -2) @ v_new
        )

    final_state = state if output_final_state else None
    # Drop the padding and return to (batch, seq, heads, v_dim).
    core_out = core_out.reshape(batch, num_heads, -1, v_dim)[:, :, :seq_len]
    core_out = core_out.transpose(1, 2).contiguous().to(out_dtype)
    return core_out, final_state
def _torch_recurrent_gated_delta_rule(
    query: torch.Tensor,  # (batch, seq, num_heads, head_k_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    g: torch.Tensor,      # (batch, seq, num_heads)
    beta: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Token-by-token gated delta rule in plain PyTorch.

    For every position the state is decayed by ``exp(g)``, read with the
    key, corrected towards the value with strength ``beta``, and finally
    read with the (scaled) query to produce the output.

    Args:
        query, key: ``(batch, seq, num_heads, head_k_dim)``.
        value: ``(batch, seq, num_heads, head_v_dim)``.
        g: Log-space decay, ``(batch, seq, num_heads)``.
        beta: Update strength, same shape as ``g``.
        initial_state: Optional starting state; zeros of shape
            ``(batch, num_heads, head_k_dim, head_v_dim)`` when omitted.
        output_final_state: Return the state after the last token when true.
        use_qk_l2norm_in_kernel: L2-normalise query and key first.

    Returns:
        ``(core_out, final_state)``: ``core_out`` has shape
        ``(batch, seq, num_heads, head_v_dim)`` in the input query dtype;
        ``final_state`` is ``None`` unless ``output_final_state`` is set.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query, key = _l2norm(query), _l2norm(key)

    def _heads_first_fp32(t: torch.Tensor) -> torch.Tensor:
        # (batch, seq, heads, ...) -> (batch, heads, seq, ...) in float32
        return t.transpose(1, 2).contiguous().to(torch.float32)

    query = _heads_first_fp32(query)
    key = _heads_first_fp32(key)
    value = _heads_first_fp32(value)
    beta = _heads_first_fp32(beta)
    g = _heads_first_fp32(g)

    batch, num_heads, seq_len, k_dim = key.shape
    v_dim = value.shape[-1]
    query = query * (1.0 / (query.shape[-1] ** 0.5))

    core_out = torch.zeros(batch, num_heads, seq_len, v_dim,
                           dtype=value.dtype, device=value.device)
    if initial_state is None:
        state = torch.zeros(batch, num_heads, k_dim, v_dim,
                            dtype=value.dtype, device=value.device)
    else:
        state = initial_state.to(value)

    for step in range(seq_len):
        decay = g[:, :, step].exp().unsqueeze(-1).unsqueeze(-1)
        key_col = key[:, :, step].unsqueeze(-1)          # (batch, heads, k_dim, 1)
        query_col = query[:, :, step].unsqueeze(-1)      # (batch, heads, k_dim, 1)
        strength = beta[:, :, step].unsqueeze(-1)        # (batch, heads, 1)

        state = state * decay
        recalled = (state * key_col).sum(dim=-2)         # (batch, heads, v_dim)
        delta = (value[:, :, step] - recalled) * strength
        state = state + key_col * delta.unsqueeze(-2)
        core_out[:, :, step] = (state * query_col).sum(dim=-2)

    final_state = state if output_final_state else None
    core_out = core_out.transpose(1, 2).contiguous().to(out_dtype)
    return core_out, final_state
# ---------------------------------------------------------------------------
# Gated RMSNorm (for DeltaNet output normalisation)
# ---------------------------------------------------------------------------
class Qwen3_5RMSNormGated(nn.Module):
    """RMS normalisation over the last dimension followed by a SiLU gate.

    The normalisation statistics are computed in float32, the learned
    ``weight`` is applied after casting back to the input dtype, and the
    result is multiplied by ``silu(gate)`` (evaluated in float32) before the
    final cast to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor,
                gate: torch.Tensor) -> torch.Tensor:
        """Return ``weight * rmsnorm(hidden_states) * silu(gate)``."""
        out_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = self.weight * x.to(out_dtype)
        gated = scaled * F.silu(gate.to(torch.float32))
        return gated.to(out_dtype)
# ---------------------------------------------------------------------------
# Gated DeltaNet (linear_attention layers)
# ---------------------------------------------------------------------------
class GatedDeltaNet(nn.Module):
    """Tensor-parallel Gated DeltaNet (linear attention) layer in pure PyTorch.

    Per token the layer projects the hidden state to a fused q/k/v stream,
    runs a depthwise causal conv + SiLU over it, applies the gated delta
    rule against a per-sequence recurrent state, gates/normalises the result
    per head and projects back to ``hidden_size``.

    Two caches are updated in place by ``forward``:
      * ``conv_state``: the last ``conv_kernel_size - 1`` pre-conv q/k/v
        columns of each sequence.
      * ``temporal_state``: the delta-rule state of each sequence.
    """

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build projections and per-rank parameters from ``text_cfg``.

        The numeric comments below are the example values noted by the
        original author for one checkpoint; they are not enforced here.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = text_cfg.hidden_size
        self.num_v_heads = text_cfg.linear_num_value_heads            # e.g. 48
        self.num_k_heads = text_cfg.linear_num_key_heads              # e.g. 16
        self.head_k_dim = text_cfg.linear_key_head_dim                # e.g. 128
        self.head_v_dim = text_cfg.linear_value_head_dim              # e.g. 128
        self.key_dim = self.num_k_heads * self.head_k_dim             # e.g. 2048
        self.value_dim = self.num_v_heads * self.head_v_dim           # e.g. 6144
        self.conv_dim = self.key_dim * 2 + self.value_dim             # e.g. 10240
        self.conv_kernel_size = text_cfg.linear_conv_kernel_dim       # e.g. 4
        self.head_expand_ratio = self.num_v_heads // self.num_k_heads  # e.g. 3
        tp_size = get_tensor_model_parallel_world_size()

        # Sharded projections: MergedColumnParallelLinear shards each of
        # q/k/v independently so each TP rank gets [q_shard, k_shard, v_shard].
        # Plain ColumnParallelLinear would shard contiguously, giving rank 0
        # [q_all, k_partial], which breaks the q/k/v split in forward().
        self.in_proj_qkv = MergedColumnParallelLinear(
            self.hidden_size, [self.key_dim, self.key_dim, self.value_dim],
            bias=False, quant_config=quant_config)
        self.in_proj_z = ColumnParallelLinear(
            self.hidden_size, self.value_dim,
            bias=False, quant_config=quant_config)
        self.in_proj_b = ColumnParallelLinear(
            self.hidden_size, self.num_v_heads,
            bias=False, quant_config=quant_config)
        self.in_proj_a = ColumnParallelLinear(
            self.hidden_size, self.num_v_heads,
            bias=False, quant_config=quant_config)
        self.out_proj = RowParallelLinear(
            self.value_dim, self.hidden_size,
            bias=False, quant_config=quant_config)

        # Depthwise conv weight, sharded along the channel dim (dim 0).
        local_conv_dim = self.conv_dim // tp_size
        self.conv1d_weight = nn.Parameter(
            torch.empty(local_conv_dim, 1, self.conv_kernel_size))
        set_weight_attrs(self.conv1d_weight, {
            "weight_loader": self._conv1d_weight_loader})

        # Per-head scalar parameters, sharded along dim 0.
        local_num_v = self.num_v_heads // tp_size
        self.A_log = nn.Parameter(torch.zeros(local_num_v))
        self.dt_bias = nn.Parameter(torch.zeros(local_num_v))
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        # Gated RMSNorm over head_v_dim; replicated on every rank.
        self.norm = Qwen3_5RMSNormGated(self.head_v_dim,
                                        eps=text_cfg.rms_norm_eps)

    def _conv1d_weight_loader(self, param: torch.Tensor,
                              loaded_weight: torch.Tensor) -> None:
        """Copy this rank's conv channels out of the full checkpoint weight.

        ``loaded_weight`` is ``(conv_dim, 1, kernel)`` with channels ordered
        ``[q, k, v]``. The rank's q, k and v slices are gathered separately
        and concatenated so that row ``i`` of ``conv1d_weight`` lines up with
        output channel ``i`` of the merged ``in_proj_qkv`` on this rank.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        key_local = self.key_dim // tp_size
        val_local = self.value_dim // tp_size
        q_s = loaded_weight[tp_rank * key_local:(tp_rank + 1) * key_local]
        k_s = loaded_weight[self.key_dim + tp_rank * key_local:
                            self.key_dim + (tp_rank + 1) * key_local]
        v_s = loaded_weight[2 * self.key_dim + tp_rank * val_local:
                            2 * self.key_dim + (tp_rank + 1) * val_local]
        param.data.copy_(torch.cat([q_s, k_s, v_s], dim=0))

    def forward(
        self,
        hidden_states: torch.Tensor,   # (total_tokens, hidden_size)
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,      # (batch, local_conv_dim, kernel-1) in-place
        temporal_state: torch.Tensor,  # (batch, local_v_heads, k_dim, v_dim) in-place
    ) -> torch.Tensor:
        """Apply the layer to a flattened batch of tokens.

        When ``attn_metadata.num_prefill_tokens > 0`` the batch is treated as
        prefill and split per sequence using ``attn_metadata.query_start_loc``;
        otherwise every row of ``hidden_states`` is one decode token of a
        distinct sequence. Row ``i`` of ``conv_state`` / ``temporal_state`` is
        used for sequence ``i`` and updated in place.

        Returns:
            ``(total_tokens, hidden_size)`` output of ``out_proj``; NaNs are
            logged and replaced by zeros.
        """
        tp_size = get_tensor_model_parallel_world_size()
        local_key_dim = self.key_dim // tp_size
        local_val_dim = self.value_dim // tp_size
        local_num_v = self.num_v_heads // tp_size
        local_num_k = self.num_k_heads // tp_size
        local_conv_dim = self.conv_dim // tp_size
        is_prefill = attn_metadata.num_prefill_tokens > 0

        # Compute all projections for every token at once.
        mixed_qkv_all, _ = self.in_proj_qkv(hidden_states)  # (total, local_conv_dim)
        z_all, _ = self.in_proj_z(hidden_states)            # (total, local_val_dim)
        b_all, _ = self.in_proj_b(hidden_states)            # (total, local_num_v)
        a_all, _ = self.in_proj_a(hidden_states)            # (total, local_num_v)

        if is_prefill:
            seq_starts = attn_metadata.query_start_loc.tolist()
            outputs = []
            state_len = self.conv_kernel_size - 1
            conv_dtype = self.conv1d_weight.dtype
            # Cap on tokens per delta-rule call; the state is chained between
            # calls via initial_state / output_final_state to bound peak memory.
            _DNN_CHUNK = 4096

            for si in range(len(seq_starts) - 1):
                s, e = int(seq_starts[si]), int(seq_starts[si + 1])
                seq_len = e - s

                # (1, local_conv_dim, seq_len)
                mixed_qkv = (mixed_qkv_all[s:e]
                             .transpose(0, 1).unsqueeze(0)
                             .to(conv_dtype))

                # Causal conv input: previous conv state (zeros for a fresh
                # request, the tail of the previous chunk for chunked
                # prefill) followed by the new tokens. torch.cat copies, so
                # conv_state can be overwritten right after.
                prev_conv = conv_state[si:si + 1].to(conv_dtype)
                padded = torch.cat([prev_conv, mixed_qkv], dim=2)

                # New conv state = last state_len columns of [prev | new].
                # This also covers chunks shorter than state_len, where the
                # older columns must be shifted over from the previous state
                # (zeroing them would drop tokens the next conv still needs).
                conv_state[si].copy_(
                    padded[0, :, padded.shape[-1] - state_len:])

                mixed_qkv_conv = F.conv1d(
                    padded, self.conv1d_weight,
                    bias=None, padding=0, groups=local_conv_dim)
                mixed_qkv_conv = F.silu(mixed_qkv_conv)
                # (1, seq_len, local_conv_dim)
                mixed_qkv_conv = mixed_qkv_conv.squeeze(0).transpose(0, 1).unsqueeze(0)

                q, k, v = torch.split(
                    mixed_qkv_conv,
                    [local_key_dim, local_key_dim, local_val_dim], dim=-1)
                q = q.reshape(1, seq_len, local_num_k, self.head_k_dim)
                k = k.reshape(1, seq_len, local_num_k, self.head_k_dim)
                v = v.reshape(1, seq_len, local_num_v, self.head_v_dim)

                beta = b_all[s:e].sigmoid().unsqueeze(0)  # (1, seq_len, local_num_v)
                g = (-self.A_log.float().exp()
                     * F.softplus(a_all[s:e].float() + self.dt_bias)
                     ).unsqueeze(0)                        # (1, seq_len, local_num_v)

                # Expand q/k heads to match the number of value heads.
                q = q.repeat_interleave(self.head_expand_ratio, dim=2)
                k = k.repeat_interleave(self.head_expand_ratio, dim=2)

                cur_state = temporal_state[si:si + 1].clone()
                core_out_parts = []
                for sc_start in range(0, seq_len, _DNN_CHUNK):
                    sc_end = min(sc_start + _DNN_CHUNK, seq_len)
                    c_out, cur_state = _torch_chunk_gated_delta_rule(
                        q[:, sc_start:sc_end],
                        k[:, sc_start:sc_end],
                        v[:, sc_start:sc_end],
                        g[:, sc_start:sc_end],
                        beta[:, sc_start:sc_end],
                        initial_state=cur_state,
                        output_final_state=True,
                        use_qk_l2norm_in_kernel=True,
                    )
                    core_out_parts.append(c_out)
                if cur_state is not None:
                    temporal_state[si].copy_(cur_state[0])
                # (1, seq_len, local_num_v, head_v_dim)
                core_out = torch.cat(core_out_parts, dim=1)

                # Gate + norm + output projection.
                z = z_all[s:e].reshape(seq_len, local_num_v, self.head_v_dim)
                core_out = core_out.reshape(seq_len, local_num_v, self.head_v_dim)
                normed = self.norm(
                    core_out.reshape(-1, self.head_v_dim),
                    z.reshape(-1, self.head_v_dim))
                normed = normed.reshape(seq_len, -1)
                out, _ = self.out_proj(normed)
                outputs.append(out)

            result = torch.cat(outputs, dim=0)
            if torch.isnan(result).any():
                logger.warning("NaN in prefill GatedDeltaNet layer %d (frac=%.4f), replacing with zeros",
                               self.layer_idx, torch.isnan(result).float().mean().item())
                result = torch.nan_to_num(result, nan=0.0)
            return result
        else:
            # Decode: one token per sequence.
            num_seqs = hidden_states.shape[0]
            weight_2d = self.conv1d_weight.squeeze(1)  # (local_conv_dim, kernel)
            # (num_seqs, local_conv_dim, 1)
            mixed_qkv = (mixed_qkv_all
                         .to(weight_2d.dtype)
                         .unsqueeze(-1))
            mixed_qkv_conv = _torch_causal_conv1d_update(
                mixed_qkv, conv_state, weight_2d,
                bias=None, activation='silu')
            # (num_seqs, local_conv_dim, 1) -> (num_seqs, 1, local_conv_dim)
            mixed_qkv_conv = mixed_qkv_conv.squeeze(-1).unsqueeze(1)

            q, k, v = torch.split(
                mixed_qkv_conv,
                [local_key_dim, local_key_dim, local_val_dim], dim=-1)
            q = q.reshape(num_seqs, 1, local_num_k, self.head_k_dim)
            k = k.reshape(num_seqs, 1, local_num_k, self.head_k_dim)
            v = v.reshape(num_seqs, 1, local_num_v, self.head_v_dim)

            beta = b_all.sigmoid().unsqueeze(1)  # (num_seqs, 1, local_num_v)
            g = (-self.A_log.float().exp()
                 * F.softplus(a_all.float() + self.dt_bias)
                 ).unsqueeze(1)                   # (num_seqs, 1, local_num_v)

            q = q.repeat_interleave(self.head_expand_ratio, dim=2)
            k = k.repeat_interleave(self.head_expand_ratio, dim=2)

            # Inlined single-step recurrence (seq_len == 1), expressed with
            # bmm / baddbmm_ on a flattened (B*H_v, k_dim, v_dim) view of
            # temporal_state so the state is updated in place.
            # NOTE(review): the view() below requires temporal_state to be
            # contiguous; its dtype is presumably float32 - confirm at caller.
            orig_dtype = q.dtype
            _scale = self.head_k_dim ** -0.5
            q_t = _l2norm(q.squeeze(1)).float() * _scale  # (B, H_v, k_dim)
            k_t = _l2norm(k.squeeze(1)).float()           # (B, H_v, k_dim)
            v_t = v.squeeze(1).float()                    # (B, H_v, v_dim)
            g_t = g.squeeze(1).float().exp_()             # (B, H_v)
            bt = beta.squeeze(1).float()                  # (B, H_v)

            # Decay the state in place with one scalar per (sequence, head).
            temporal_state.mul_(g_t[:, :, None, None])
            ts_flat = temporal_state.view(-1, self.head_k_dim, self.head_v_dim)
            BH = ts_flat.shape[0]

            # kv_mem = k_t @ state: (BH, 1, k_dim) @ (BH, k_dim, v_dim)
            kv_mem = torch.bmm(
                k_t.view(BH, 1, self.head_k_dim), ts_flat
            ).view(num_seqs, local_num_v, self.head_v_dim)  # (B, H_v, v_dim)
            delta = (v_t - kv_mem) * bt[:, :, None]          # (B, H_v, v_dim)

            # state += outer(k_t, delta), fused and in place.
            ts_flat.baddbmm_(
                k_t.view(BH, self.head_k_dim, 1),
                delta.view(BH, 1, self.head_v_dim),
            )

            # Output: q_t @ updated state.
            core_out = torch.bmm(
                q_t.view(BH, 1, self.head_k_dim), ts_flat
            ).view(num_seqs, local_num_v, self.head_v_dim).to(orig_dtype)

            z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
            normed = self.norm(
                core_out.reshape(-1, self.head_v_dim),
                z.reshape(-1, self.head_v_dim))
            normed = normed.reshape(num_seqs, -1)
            out, _ = self.out_proj(normed)
            if torch.isnan(out).any():
                logger.warning("NaN in decode GatedDeltaNet layer %d (frac=%.4f), replacing with zeros",
                               self.layer_idx, torch.isnan(out).float().mean().item())
                out = torch.nan_to_num(out, nan=0.0)
            return out
# ---------------------------------------------------------------------------
# Full Attention (with gated q — unique to Qwen3.5)
# ---------------------------------------------------------------------------
class Qwen3_5FullAttention(nn.Module):
    """Full (softmax) attention layer with an output gate fused into q_proj.

    ``q_proj`` emits ``2 * head_dim`` values per head: the first half is the
    query, the second half a gate whose sigmoid multiplies the attention
    output before ``o_proj``. Queries and keys are RMS-normalised per head
    and receive partial rotary embeddings.

    KV sharding under tensor parallelism:
      * ``tp_size <= num_kv_heads``: k/v projections are column-sharded.
      * ``tp_size > num_kv_heads``: k/v projections are replicated and each
        rank keeps only the single KV head that serves its query heads.
    """

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build projections, norms, rotary embedding and attention backend.

        Raises:
            ValueError: If the query heads cannot be split evenly across TP
                ranks, or if (in the replicated-KV case) one rank's query
                heads would need more than one KV head.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = text_cfg.hidden_size
        self.num_heads = text_cfg.num_attention_heads
        self.num_kv_heads = text_cfg.num_key_value_heads
        self.head_dim = text_cfg.head_dim
        self.rms_norm_eps = text_cfg.rms_norm_eps
        tp_size = get_tensor_model_parallel_world_size()
        if self.num_heads % tp_size != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible "
                f"by the tensor-parallel size ({tp_size})")
        self.local_num_heads = self.num_heads // tp_size
        self.scaling = self.head_dim ** -0.5

        if tp_size > self.num_kv_heads:
            # KV heads cannot be sharded further without fractional heads per
            # rank. Replicate all KV heads on every rank and pick, in
            # forward(), the one head that serves this rank's query heads:
            #   q_per_kv = num_heads // num_kv_heads
            #   rank r uses KV head (r * local_num_heads) // q_per_kv
            # (Original author's note: the ixformer kernels on this target
            # are only exercised with one KV head per rank.)
            self.proj_kv_heads = self.num_kv_heads  # heads produced by the projection
            self.local_num_kv_heads = 1             # heads kept after selection
            self.q_per_kv_global = self.num_heads // self.num_kv_heads
            # The single-head selection is only valid when a rank's block of
            # query heads never straddles two KV groups.
            if self.q_per_kv_global % self.local_num_heads != 0:
                raise ValueError(
                    f"Cannot map each tensor-parallel rank to one KV head: "
                    f"num_attention_heads={self.num_heads}, "
                    f"num_key_value_heads={self.num_kv_heads}, "
                    f"tp_size={tp_size}")
            self.k_proj = ReplicatedLinear(
                self.hidden_size, self.num_kv_heads * self.head_dim,
                bias=False, quant_config=quant_config)
            self.v_proj = ReplicatedLinear(
                self.hidden_size, self.num_kv_heads * self.head_dim,
                bias=False, quant_config=quant_config)
        else:
            # Standard sharding: each rank gets num_kv_heads // tp_size heads.
            self.local_num_kv_heads = self.num_kv_heads // tp_size
            self.proj_kv_heads = self.local_num_kv_heads  # already sharded
            self.q_per_kv_global = None
            self.k_proj = ColumnParallelLinear(
                self.hidden_size, self.num_kv_heads * self.head_dim,
                bias=False, quant_config=quant_config,
                prefix=f"{prefix}.k_proj")
            self.v_proj = ColumnParallelLinear(
                self.hidden_size, self.num_kv_heads * self.head_dim,
                bias=False, quant_config=quant_config,
                prefix=f"{prefix}.v_proj")
        self.local_q_dim = self.local_num_heads * self.head_dim
        self.local_kv_dim = self.local_num_kv_heads * self.head_dim

        # q_proj includes the gate: output = num_heads * head_dim * 2.
        self.q_proj = ColumnParallelLinear(
            self.hidden_size, self.num_heads * self.head_dim * 2,
            bias=False, quant_config=quant_config,
            prefix=f"{prefix}.q_proj")
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim, self.hidden_size,
            bias=False, quant_config=quant_config,
            prefix=f"{prefix}.o_proj")
        self.q_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps)

        # Partial RoPE: only the first head_dim * partial_rotary_factor dims
        # are rotated. Settings are looked up in ``rope_parameters`` first,
        # then as plain config attributes, then fall back to the defaults.
        rope_params = getattr(text_cfg, "rope_parameters", None) or {}

        def _rope_setting(name: str, default):
            value = rope_params.get(name)
            if value is None:
                value = getattr(text_cfg, name, None)
            return default if value is None else value

        rope_theta = _rope_setting("rope_theta", 10_000_000)
        partial_factor = _rope_setting("partial_rotary_factor", 0.25)
        rotary_dim = int(self.head_dim * partial_factor)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=rotary_dim,
            max_position=text_cfg.max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(
            self.local_num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.local_num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run gated attention over ``hidden_states`` (``(total, hidden_size)``).

        Returns the ``o_proj`` output for every token.
        """
        total_tokens = hidden_states.shape[0]

        # q_proj output interleaves [query | gate] per head.
        qg, _ = self.q_proj(hidden_states)  # (total, local_num_heads * head_dim * 2)
        qg = qg.view(total_tokens, self.local_num_heads, self.head_dim * 2)
        q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
        gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)
        k, _ = self.k_proj(hidden_states)   # (total, proj_kv_heads * head_dim)
        v, _ = self.v_proj(hidden_states)

        # Per-head RMSNorm on the local query heads.
        q = self.q_norm.forward_cuda(
            q.view(total_tokens, self.local_num_heads, self.head_dim)
            .contiguous()).view(total_tokens, -1)

        # Replicated-KV case: select this rank's KV head BEFORE k_norm and
        # rope, so those ops only ever see local_num_kv_heads heads.
        if self.q_per_kv_global is not None:
            tp_rank = get_tensor_model_parallel_rank()
            kv_idx = (tp_rank * self.local_num_heads) // self.q_per_kv_global
            k = (k.view(total_tokens, self.proj_kv_heads, self.head_dim)
                 [:, kv_idx, :].contiguous())  # (total, head_dim)
            v = (v.view(total_tokens, self.proj_kv_heads, self.head_dim)
                 [:, kv_idx, :].contiguous())  # (total, head_dim)

        # Per-head RMSNorm on the rank-local KV heads.
        k = self.k_norm.forward_cuda(
            k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
            .contiguous()).view(total_tokens, -1)

        q, k = self.rotary_emb(positions, q, k)
        attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
        # Multiply by the sigmoid gate before the output projection.
        attn_out = attn_out * torch.sigmoid(gate.float()).to(attn_out.dtype)
        output, _ = self.o_proj(attn_out)
        return output
# ---------------------------------------------------------------------------
# MLP (SwiGLU, same as Qwen2/Qwen3)
# ---------------------------------------------------------------------------
class Qwen3_5MLP(nn.Module):
    """Dense SwiGLU feed-forward block with tensor-parallel projections.

    A merged gate/up column-parallel projection feeds ``SiluAndMul``; the
    result goes through a row-parallel down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the projections; only ``hidden_act == "silu"`` is accepted.

        Raises:
            ValueError: For any other activation name.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(silu(gate) * up)`` for input ``x``."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
# ---------------------------------------------------------------------------
# MoE sparse block (Qwen3.5-MoE / Qwen3.6-35B-A3B)
# ---------------------------------------------------------------------------
class Qwen3_5MoeSparseBlock(nn.Module):
    """Sparse MoE feed-forward block used in place of Qwen3_5MLP.

    ``FusedMoE`` is instantiated only to own and load the expert weights;
    its forward is never called. Per the original author's note, the fused
    MoE kernels are unavailable on the target device, so routing and expert
    evaluation are done with plain PyTorch ops instead.

    Both the routed experts and the shared expert produce partial
    (pre-all-reduce) outputs; they are summed and reduced once in
    ``forward``.
    """

    def __init__(
        self,
        text_cfg,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        hidden_size = text_cfg.hidden_size
        self.num_experts = text_cfg.num_experts
        self.top_k = text_cfg.num_experts_per_tok

        # Router: replicated, one logit per expert.
        self.gate = ReplicatedLinear(
            hidden_size, text_cfg.num_experts,
            bias=False, quant_config=quant_config)

        # Weight container only; see _pure_pytorch_experts() for the compute.
        self.experts = FusedMoE(
            num_experts=text_cfg.num_experts,
            top_k=text_cfg.num_experts_per_tok,
            hidden_size=hidden_size,
            intermediate_size=text_cfg.moe_intermediate_size,
            reduce_results=False,  # the all-reduce happens in forward()
            renormalize=True,
            quant_config=quant_config,
        )

        # Shared expert; its all-reduce is deferred so it can be merged with
        # the routed output first.
        shared_size = text_cfg.shared_expert_intermediate_size
        self.shared_expert_gate_up = MergedColumnParallelLinear(
            hidden_size, [shared_size] * 2, bias=False,
            quant_config=quant_config)
        self.shared_expert_down = RowParallelLinear(
            shared_size, hidden_size, bias=False, reduce_results=False,
            quant_config=quant_config)
        self.act_fn = SiluAndMul()

        # Scalar gate on the shared expert:
        #   shared_out *= sigmoid(shared_expert_gate(hidden_states))
        self.shared_expert_gate = ReplicatedLinear(
            hidden_size, 1, bias=False, quant_config=quant_config)

    def _route(self, router_logits: torch.Tensor, dtype: torch.dtype):
        """Softmax over experts, keep the top-k, renormalise to sum to 1.

        Returns ``(topk_weights, topk_ids)``, both ``(T, top_k)``; the
        weights are cast to ``dtype``.
        """
        probs = torch.softmax(router_logits.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(probs, self.top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights.to(dtype), topk_ids

    def _single_token_experts(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w13: torch.Tensor,
        w2: torch.Tensor,
    ) -> torch.Tensor:
        """Decode fast path for exactly one token.

        The selected experts' gate/up weights are stacked into one matrix for
        a single linear call; the down projections run as one batched matmul.
        """
        expert_ids = topk_ids[0]                                  # (K,)
        expert_w = topk_weights[0].to(hidden_states.dtype)        # (K,)
        w13_sel = w13[expert_ids]                                 # (K, 2*I, H)
        w2_sel = w2[expert_ids]                                   # (K, H, I)
        hidden_dim = hidden_states.shape[-1]

        fused = F.linear(hidden_states, w13_sel.reshape(-1, hidden_dim))  # (1, K*2*I)
        fused = fused.view(self.top_k, -1)                        # (K, 2*I)
        gate_part, up_part = fused.chunk(2, dim=-1)               # (K, I) each
        activated = F.silu(gate_part) * up_part                   # (K, I)
        # (K, H, I) @ (K, I, 1) -> (K, H, 1) -> (K, H)
        per_expert = torch.bmm(w2_sel, activated.unsqueeze(-1)).squeeze(-1)
        combined = (per_expert * expert_w.unsqueeze(-1)).sum(0, keepdim=True)
        return combined.to(hidden_states.dtype)                   # (1, H)

    def _multi_token_experts(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w13: torch.Tensor,
        w2: torch.Tensor,
    ) -> torch.Tensor:
        """General path: visit every expert selected by at least one token."""
        out = torch.zeros_like(hidden_states)
        active_experts = topk_ids.view(-1).unique().tolist()
        for expert in active_experts:
            expert = int(expert)
            # Rows = tokens routed to this expert, cols = their top-k slot.
            token_idx, slot_idx = (topk_ids == expert).nonzero(as_tuple=True)
            fused = F.linear(hidden_states[token_idx], w13[expert])   # (n, 2*I)
            gate_part, up_part = fused.chunk(2, dim=-1)
            activated = F.silu(gate_part) * up_part                   # (n, I)
            expert_out = F.linear(activated, w2[expert])              # (n, H)
            scale = topk_weights[token_idx, slot_idx].unsqueeze(-1)
            out.index_add_(0, token_idx, (expert_out * scale).to(out.dtype))
        return out

    def _pure_pytorch_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the routed experts without fused MoE kernels.

        Expert weights are read from the ``FusedMoE`` container:
          w13_weight: (num_experts, 2 * inter_per_partition, hidden)
          w2_weight:  (num_experts, hidden, inter_per_partition)

        The result is partial (pre-all-reduce); the reduction is done by
        ``forward``.
        """
        topk_weights, topk_ids = self._route(router_logits, hidden_states.dtype)
        w13 = self.experts.w13_weight
        w2 = self.experts.w2_weight
        if hidden_states.shape[0] == 1:
            return self._single_token_experts(
                hidden_states, topk_weights, topk_ids, w13, w2)
        return self._multi_token_experts(
            hidden_states, topk_weights, topk_ids, w13, w2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return routed + gated shared expert output, reduced across TP ranks."""
        router_logits, _ = self.gate(hidden_states)
        routed = self._pure_pytorch_experts(hidden_states, router_logits)

        shared_fused, _ = self.shared_expert_gate_up(hidden_states)
        shared, _ = self.shared_expert_down(self.act_fn(shared_fused))
        shared_logit, _ = self.shared_expert_gate(hidden_states)  # (T, 1)
        shared = shared * torch.sigmoid(shared_logit)

        combined = routed + shared
        if self.experts.tp_size > 1:
            combined = tensor_model_parallel_all_reduce(combined)
        return combined
# ---------------------------------------------------------------------------
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
# ---------------------------------------------------------------------------
class Qwen3_5DecoderLayer(nn.Module):
    """One decoder layer: token mixer + feed-forward with fused residual norms.

    The token mixer is ``GatedDeltaNet`` when ``layer_type`` is
    ``"linear_attention"`` and ``Qwen3_5FullAttention`` otherwise. The
    feed-forward is the sparse MoE block when the config's ``model_type`` is
    ``"qwen3_5_moe_text"`` and the dense SwiGLU MLP otherwise.
    """

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        layer_type: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = layer_type

        self.input_layernorm = GemmaRMSNorm(
            text_cfg.hidden_size, eps=text_cfg.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(
            text_cfg.hidden_size, eps=text_cfg.rms_norm_eps)

        # Token mixer.
        if layer_type == "linear_attention":
            self.linear_attn = GatedDeltaNet(
                text_cfg, layer_idx, quant_config=quant_config)
        else:
            self.self_attn = Qwen3_5FullAttention(
                text_cfg,
                layer_idx,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"layers.{layer_idx}.self_attn",
            )

        # Feed-forward.
        uses_moe = getattr(text_cfg, 'model_type', '') == 'qwen3_5_moe_text'
        if uses_moe:
            self.mlp = Qwen3_5MoeSparseBlock(text_cfg, quant_config=quant_config)
        else:
            self.mlp = Qwen3_5MLP(
                hidden_size=text_cfg.hidden_size,
                intermediate_size=text_cfg.intermediate_size,
                hidden_act=text_cfg.hidden_act,
                quant_config=quant_config,
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        # Only for linear_attention layers:
        conv_state: Optional[torch.Tensor] = None,
        temporal_state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        With ``residual=None`` (first layer) the incoming hidden states
        become the residual; otherwise the input norm is called with both
        tensors and returns the updated pair. ``conv_state`` and
        ``temporal_state`` are only forwarded to the linear-attention mixer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual, hidden_states = hidden_states, self.input_layernorm(hidden_states)

        if self.layer_type == "linear_attention":
            mixed = self.linear_attn(
                hidden_states, attn_metadata, conv_state, temporal_state)
        else:
            mixed = self.self_attn(
                positions, hidden_states, kv_cache, attn_metadata)

        normed, residual = self.post_attention_layernorm(mixed, residual)
        return self.mlp(normed), residual
# ---------------------------------------------------------------------------
# Full transformer model
# ---------------------------------------------------------------------------
class Qwen3_5Model(nn.Module):
    """Decoder stack: token embedding, hybrid decoder layers, final norm.

    Each layer's kind is taken from ``text_cfg.layer_types``; linear-attention
    layers consume per-layer DeltaNet state while the remaining layers
    consume entries of the KV-cache list.
    """

    def __init__(
        self,
        text_cfg,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the embedding, ``num_hidden_layers`` layers and the norm.

        Args:
            text_cfg: Text-model configuration (vocab size, hidden size,
                layer types, norm epsilon, ...).
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
        """
        super().__init__()
        self.text_cfg = text_cfg
        self.embed_tokens = VocabParallelEmbedding(
            text_cfg.vocab_size, text_cfg.hidden_size)

        def build_layer(idx: int):
            # One decoder layer whose kind comes from the config's list.
            return Qwen3_5DecoderLayer(
                text_cfg, idx, text_cfg.layer_types[idx],
                cache_config=cache_config, quant_config=quant_config)

        self.layers = nn.ModuleList(
            [build_layer(idx) for idx in range(text_cfg.num_hidden_layers)])
        self.norm = GemmaRMSNorm(text_cfg.hidden_size,
                                 eps=text_cfg.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        conv_states: torch.Tensor,      # (num_linear_layers, batch, ...)
        temporal_states: torch.Tensor,  # (num_linear_layers, batch, ...)
    ) -> torch.Tensor:
        """Embed the tokens, run every layer and apply the final norm.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, passed to every layer.
            kv_caches: One entry per full-attention layer, in layer order.
            attn_metadata: Attention metadata, passed to every layer.
            conv_states: Convolution state, indexed along the first axis by
                the running linear-attention layer index.
            temporal_states: Recurrent state, indexed the same way.

        Returns:
            The hidden states after the final norm.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None

        # Separate running indices: the two layer kinds draw from two
        # differently sized state containers.
        attn_idx = 0
        linear_idx = 0
        for layer in self.layers:
            if layer.layer_type == "linear_attention":
                cache = None
                state_kwargs = {
                    "conv_state": conv_states[linear_idx],
                    "temporal_state": temporal_states[linear_idx],
                }
                linear_idx += 1
            else:
                cache = kv_caches[attn_idx]
                state_kwargs = {}
                attn_idx += 1

            hidden_states, residual = layer(
                positions, hidden_states,
                kv_cache=cache,
                attn_metadata=attn_metadata,
                residual=residual,
                **state_kwargs,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
# ---------------------------------------------------------------------------
# Top-level CausalLM wrapper with MambaCacheManager
# ---------------------------------------------------------------------------
class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
    """Causal-LM wrapper around ``Qwen3_5Model`` (dense MLP variant).

    The full-attention layers use the KV caches passed to ``forward``.  The
    linear-attention (DeltaNet) layers additionally need per-sequence
    convolution and recurrent state, which is held by a ``MambaCacheManager``
    created lazily on the first ``forward`` call.
    """

    has_inner_state = True
    supports_lora = True
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    supported_lora_modules = [
        "gate_up_proj",
        "down_proj",
        "o_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config,  # Qwen3_5Config (top-level)
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the decoder stack, LM head, logits processor and sampler.

        Args:
            config: Top-level config; architecture parameters are read from
                ``config.text_config``.
            cache_config: Forwarded to the decoder stack.
            quant_config: Forwarded to the decoder stack and the LM head.
            lora_config: Accepted for interface compatibility; not read here.
            scheduler_config: Used in ``forward`` to size the DeltaNet state
                cache from ``max_num_seqs``.
            prefix: Accepted for interface compatibility; not read here.
        """
        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config

        # The text config holds all architecture parameters
        text_cfg = config.text_config
        self.text_cfg = text_cfg

        # Pre-compute counts
        self.num_linear_layers = sum(
            1 for lt in text_cfg.layer_types if lt == "linear_attention")
        self.num_attn_layers = sum(
            1 for lt in text_cfg.layer_types if lt == "full_attention")

        # DeltaNet state dimensions (per layer, per sequence).  These are the
        # full (unsharded) sizes; _get_mamba_cache_shape() divides them by
        # the tensor-parallel world size.
        self.conv_dim = (
            text_cfg.linear_num_key_heads * text_cfg.linear_key_head_dim * 2
            + text_cfg.linear_num_value_heads * text_cfg.linear_value_head_dim)
        self.num_v_heads = text_cfg.linear_num_value_heads
        self.head_k_dim = text_cfg.linear_key_head_dim
        self.head_v_dim = text_cfg.linear_value_head_dim
        self.conv_kernel_size = text_cfg.linear_conv_kernel_dim

        self.model = Qwen3_5Model(
            text_cfg,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.lm_head = ParallelLMHead(
            text_cfg.vocab_size, text_cfg.hidden_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(text_cfg.vocab_size)
        self.sampler = Sampler()

        # Lazy initialised in first forward call
        self.mamba_cache: Optional[MambaCacheManager] = None

    def _get_mamba_cache_shape(self):
        """Return ``(conv_state_shape, temporal_state_shape)`` for one rank.

        Both shapes describe the state of a single sequence in a single
        linear-attention layer, with the channel / head axis divided by the
        tensor-parallel world size.
        """
        tp_size = get_tensor_model_parallel_world_size()
        # Each sequence's state is stored in float32
        conv_state_shape = (self.conv_dim // tp_size,
                            self.conv_kernel_size - 1)
        temporal_state_shape = (
            self.num_v_heads // tp_size, self.head_k_dim, self.head_v_dim)
        return conv_state_shape, temporal_state_shape

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final hidden states.

        Args:
            input_ids: Token ids.
            positions: Token positions.
            kv_caches: KV caches for the full-attention layers.
            attn_metadata: Attention metadata for this step.
            intermediate_tensors: Accepted for interface compatibility; not
                used by this implementation.
            **kwargs: Forwarded to ``MambaCacheManager.current_run_tensors``.

        Returns:
            Hidden states produced by ``self.model``.
        """
        if self.mamba_cache is None:
            # Size the state cache once, on first use.
            if self.scheduler_config is not None:
                max_batch_size = _get_graph_batch_size(
                    self.scheduler_config.max_num_seqs)
            else:
                max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + 2
            self.mamba_cache = MambaCacheManager(
                torch.float32,
                self.num_linear_layers,
                max_batch_size,
                *self._get_mamba_cache_shape(),
            )

        mamba_tensors = self.mamba_cache.current_run_tensors(
            input_ids, attn_metadata, **kwargs)
        # conv_states:     (num_linear_layers, batch, local_conv_dim, kernel-1)
        # temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim)
        conv_states, temporal_states = mamba_tensors

        hidden_states = self.model(
            input_ids, positions, kv_caches, attn_metadata,
            conv_states, temporal_states)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        # All TP ranks must call logits_processor to participate in the NCCL
        # gather inside lm_head.  Non-driver ranks return None after the
        # gather.  With chunked prefill, intermediate chunks have
        # seq_groups=None on all ranks; _apply_logits_processors is guarded
        # against this in logits_processor.py (patched by
        # patch_xformers_sdpa_seq.py).
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the state cache; requires a prior ``forward`` call."""
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the state cache; requires a prior ``forward`` call."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Vision and MTP tensors and rotary ``inv_freq`` buffers are skipped,
        the ``model.language_model.`` prefix is rewritten to ``model.``, and
        separate ``gate_proj`` / ``up_proj`` tensors are loaded into the
        merged ``gate_up_proj`` parameter.  Tensors whose (remapped) name has
        no matching parameter are skipped silently.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, weight_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Vision and MTP branches are not part of this text-only model.
        skipped_prefixes = ("model.visual", "mtp.", "model.mtp")
        wrapped_prefix = "model.language_model."
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if name.startswith(skipped_prefixes):
                continue

            # Prefix remapping: checkpoint may wrap under language_model
            if name.startswith(wrapped_prefix):
                name = "model." + name[len(wrapped_prefix):]

            # Skip positional embedding caches
            if "rotary_emb.inv_freq" in name:
                continue

            # Remap conv1d.weight -> conv1d_weight (the parameter is stored
            # under a flat name in this implementation).
            if ".linear_attn.conv1d.weight" in name:
                name = name.replace(".linear_attn.conv1d.weight",
                                    ".linear_attn.conv1d_weight")

            # Stacked param loading (gate_up_proj): the first mapping entry
            # whose weight_name occurs in the name handles the tensor.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name in params_dict:
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping applied: plain one-to-one load.
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
# ---------------------------------------------------------------------------
# Qwen3.6-35B-A3B (Qwen3_5-MoE architecture)
# ---------------------------------------------------------------------------
class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLM):
    """Qwen3.6-35B-A3B: same hybrid-attention backbone as 27B, dense MLP
    replaced by Qwen3_5MoeSparseBlock (256 routed experts + shared expert).

    Only load_weights differs from the dense variant.
    """

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load a Qwen3.5-MoE checkpoint into this module's parameters.

        Checkpoint key format expected for this model (transformers
        Qwen3_5MoeExperts), as assumed by the code below:
            mlp.experts.gate_up_proj  (num_experts, 2*intermediate, hidden)
            mlp.experts.down_proj     (num_experts, hidden, intermediate)
            mlp.gate.weight           (num_experts, hidden)  [router]
            mlp.shared_expert.{gate,up,down}_proj.weight     [shared MLP]

        Target parameters:
            mlp.experts.w13_weight / mlp.experts.w2_weight   [routed experts]
            mlp.shared_expert_gate_up.weight (merged gate+up)
            mlp.shared_expert_down.weight

        Tensors whose (remapped) name has no matching parameter are skipped
        silently.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, weight_name, shard_id)
            # Shared expert entries come first: the first match wins, and
            # the generic "gate_proj"/"up_proj" entries below would also
            # match these names.
            ("shared_expert_gate_up", "shared_expert.gate_proj", 0),
            ("shared_expert_gate_up", "shared_expert.up_proj", 1),
            # Generic gate/up merge (same entries as the dense variant)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Vision and MTP branches are not part of this text-only model.
        skipped_prefixes = ("model.visual", "mtp.", "model.mtp")
        wrapped_prefix = "model.language_model."
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if name.startswith(skipped_prefixes):
                continue

            # Prefix remapping: checkpoint may wrap under language_model, so
            # model.language_model.<rest> becomes model.<rest>.  Names
            # without that prefix are left untouched.
            if name.startswith(wrapped_prefix):
                name = "model." + name[len(wrapped_prefix):]

            if "rotary_emb.inv_freq" in name:
                continue

            if ".linear_attn.conv1d.weight" in name:
                name = name.replace(".linear_attn.conv1d.weight",
                                    ".linear_attn.conv1d_weight")

            # --- Fused routed-expert weights (all experts in one tensor) ---
            if "mlp.experts.gate_up_proj" in name:
                # loaded_weight: (num_experts, 2*intermediate, hidden)
                w13_name = name.replace("mlp.experts.gate_up_proj",
                                        "mlp.experts.w13_weight")
                if w13_name not in params_dict:
                    continue
                param = params_dict[w13_name]
                inter = loaded_weight.shape[1] // 2
                for eid in range(loaded_weight.shape[0]):
                    # Slice per expert instead of copying the whole fused
                    # tensor twice: a row slice of one expert is already
                    # contiguous when the checkpoint tensor is, so
                    # .contiguous() is then a no-op.
                    expert_w = loaded_weight[eid]
                    param.weight_loader(
                        param, expert_w[:inter].contiguous(),
                        "w1_weight", "w1", eid)
                    param.weight_loader(
                        param, expert_w[inter:].contiguous(),
                        "w3_weight", "w3", eid)
                continue

            if "mlp.experts.down_proj" in name:
                # loaded_weight: (num_experts, hidden, intermediate)
                w2_name = name.replace("mlp.experts.down_proj",
                                       "mlp.experts.w2_weight")
                if w2_name not in params_dict:
                    continue
                param = params_dict[w2_name]
                for eid in range(loaded_weight.shape[0]):
                    param.weight_loader(
                        param, loaded_weight[eid], "w2_weight", "w2", eid)
                continue

            # --- Shared expert down_proj rename ---
            if "mlp.shared_expert.down_proj" in name:
                name = name.replace("mlp.shared_expert.down_proj",
                                    "mlp.shared_expert_down")
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                continue

            # --- Stacked / standard weights ---
            # The first mapping entry whose weight_name occurs in the name
            # handles the tensor; the else branch is the one-to-one load.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name in params_dict:
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)