initial commit for qwen3.6-moe adaptation

This commit is contained in:
2026-06-12 10:10:49 +08:00
parent 365da18436
commit 629f878c28
6 changed files with 560 additions and 49 deletions

View File

@@ -11,12 +11,15 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -495,24 +498,52 @@ class Qwen3_5FullAttention(nn.Module):
tp_size = get_tensor_model_parallel_world_size()
self.local_num_heads = self.num_heads // tp_size
self.local_num_kv_heads = max(1, self.num_kv_heads // tp_size)
self.scaling = self.head_dim ** -0.5
# When num_kv_heads < tp_size we cannot shard KV further (would give
# fractional heads per rank). Use ReplicatedLinear so every rank projects
# all KV heads, then keep only the single KV head that serves its local
# Q heads (local_num_kv_heads = 1 after the selection done in forward()).
# When num_kv_heads >= tp_size standard ColumnParallel sharding applies.
if tp_size > self.num_kv_heads:
# GQA-aware TP sharding: ixformer kernel only supports num_kv_heads=1
# per rank. With num_kv_heads=2 < tp_size=4 we cannot shard KV
# evenly, but we CAN assign each rank the ONE KV head that serves
# its Q heads:
# q_per_kv = num_heads // num_kv_heads (e.g. 16//2 = 8)
# Rank r uses KV head r * local_num_heads // q_per_kv
# e.g. ranks 0,1 → KV head 0; ranks 2,3 → KV head 1.
# We replicate all KV heads to every rank and select in forward().
self.proj_kv_heads = self.num_kv_heads # heads available from projection
self.local_num_kv_heads = 1 # heads after rank-local selection
self.q_per_kv_global = self.num_heads // self.num_kv_heads
self.k_proj = ReplicatedLinear(
self.hidden_size, self.num_kv_heads * self.head_dim,
bias=False, quant_config=quant_config)
self.v_proj = ReplicatedLinear(
self.hidden_size, self.num_kv_heads * self.head_dim,
bias=False, quant_config=quant_config)
else:
# Standard sharding: each rank gets num_kv_heads // tp_size heads.
self.local_num_kv_heads = self.num_kv_heads // tp_size
self.proj_kv_heads = self.local_num_kv_heads # already sharded
self.q_per_kv_global = None
self.k_proj = ColumnParallelLinear(
self.hidden_size, self.num_kv_heads * self.head_dim,
bias=False, quant_config=quant_config,
prefix=f"{prefix}.k_proj")
self.v_proj = ColumnParallelLinear(
self.hidden_size, self.num_kv_heads * self.head_dim,
bias=False, quant_config=quant_config,
prefix=f"{prefix}.v_proj")
self.local_q_dim = self.local_num_heads * self.head_dim
self.local_kv_dim = self.local_num_kv_heads * self.head_dim
self.scaling = self.head_dim ** -0.5
# q_proj includes gate: output = num_heads * head_dim * 2
self.q_proj = ColumnParallelLinear(
self.hidden_size, self.num_heads * self.head_dim * 2,
bias=False, quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.k_proj = ColumnParallelLinear(
self.hidden_size, self.num_kv_heads * self.head_dim,
bias=False, quant_config=quant_config,
prefix=f"{prefix}.k_proj")
self.v_proj = ColumnParallelLinear(
self.hidden_size, self.num_kv_heads * self.head_dim,
bias=False, quant_config=quant_config,
prefix=f"{prefix}.v_proj")
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim, self.hidden_size,
bias=False, quant_config=quant_config,
@@ -559,18 +590,34 @@ class Qwen3_5FullAttention(nn.Module):
q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)
k, _ = self.k_proj(hidden_states) # (total, local_kv_dim)
k, _ = self.k_proj(hidden_states) # (total, proj_kv_heads * head_dim)
v, _ = self.v_proj(hidden_states)
# Per-head RMSNorm
# q_norm on local Q heads
q = self.q_norm.forward_cuda(
q.view(total_tokens, self.local_num_heads, self.head_dim)
.contiguous()).view(total_tokens, -1)
# GQA-aware TP: select rank-local KV head BEFORE k_norm and rope so
# that ixformer kernels always see num_kv_heads=1 (same as 27B path).
# Doing k_norm/rope on 2 KV heads (proj_kv_heads=2) triggers ixformer
# paths that can produce NaN; restricting to 1 head avoids the issue.
if self.q_per_kv_global is not None:
tp_rank = get_tensor_model_parallel_rank()
kv_idx = (tp_rank * self.local_num_heads) // self.q_per_kv_global
k = (k.view(total_tokens, self.proj_kv_heads, self.head_dim)
[:, kv_idx, :].contiguous()) # (T, head_dim) — 1 head
v = (v.view(total_tokens, self.proj_kv_heads, self.head_dim)
[:, kv_idx, :].contiguous()) # (T, head_dim) — 1 head
# k_norm on the (now always 1) rank-local KV head
k = self.k_norm.forward_cuda(
k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
.contiguous()).view(total_tokens, -1)
# rope: q=(T, local_num_heads*head_dim), k=(T, 1*head_dim) — mirrors 27B
q, k = self.rotary_emb(positions, q, k)
attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
# Multiply by sigmoid gate before output projection
@@ -609,10 +656,130 @@ class Qwen3_5MLP(nn.Module):
return x
# ---------------------------------------------------------------------------
# MoE sparse block (Qwen3.5-MoE / Qwen3.6-35B-A3B)
# ---------------------------------------------------------------------------
class Qwen3_5MoeSparseBlock(nn.Module):
    """Sparse MoE feed-forward block; replaces Qwen3_5MLP for
    ``qwen3_5_moe_text`` layers.

    FusedMoE is used ONLY for weight storage and loading (create_weights /
    weight_loader are pure PyTorch).  Its forward kernel is bypassed because
    ixformer on BI-V100 lacks vllm_moe_topk_softmax /
    vllm_invoke_fused_moe_kernel.  Routing and expert computation use a
    pure-PyTorch loop instead (see ``_pure_pytorch_experts``).

    The shared expert uses RowParallelLinear(reduce_results=False) so that
    both the routed and the shared path produce partial (pre-all-reduce)
    outputs, which are summed and then reduced with a single all-reduce.
    """

    def __init__(
        self,
        text_cfg,
        quant_config: Optional["QuantizationConfig"] = None,
    ) -> None:
        """Build router, routed-expert storage and the gated shared expert.

        Args:
            text_cfg: text model config; this block reads ``hidden_size``,
                ``num_experts``, ``num_experts_per_tok``,
                ``moe_intermediate_size`` and
                ``shared_expert_intermediate_size`` from it.
            quant_config: optional quantization config forwarded unchanged
                to every linear / MoE layer created here.
        """
        super().__init__()
        hidden_size = text_cfg.hidden_size
        self.num_experts = text_cfg.num_experts
        self.top_k = text_cfg.num_experts_per_tok

        # Router: replicated on every rank (small: num_experts outputs).
        self.gate = ReplicatedLinear(hidden_size, text_cfg.num_experts,
                                     bias=False, quant_config=quant_config)

        # FusedMoE: only used for weight storage + weight_loader.
        # Forward is bypassed — see _pure_pytorch_experts().
        self.experts = FusedMoE(
            num_experts=text_cfg.num_experts,
            top_k=text_cfg.num_experts_per_tok,
            hidden_size=hidden_size,
            intermediate_size=text_cfg.moe_intermediate_size,
            reduce_results=False,  # the all-reduce is done in forward()
            renormalize=True,
            quant_config=quant_config,
        )

        # Shared expert: defer its all-reduce so it can be combined with the
        # routed output first.
        shared_size = text_cfg.shared_expert_intermediate_size
        self.shared_expert_gate_up = MergedColumnParallelLinear(
            hidden_size, [shared_size] * 2, bias=False,
            quant_config=quant_config)
        self.shared_expert_down = RowParallelLinear(
            shared_size, hidden_size, bias=False, reduce_results=False,
            quant_config=quant_config)
        self.act_fn = SiluAndMul()

        # Scalar sigmoid gate on the shared expert output
        # (same as Qwen2-MoE / Qwen3.5-MoE):
        #     shared_out *= sigmoid(shared_expert_gate(hidden_states))
        # Without this, the shared expert is always fully active → wrong
        # logits.
        self.shared_expert_gate = ReplicatedLinear(
            hidden_size, 1, bias=False, quant_config=quant_config)

    def _pure_pytorch_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Pure-PyTorch MoE (ixformer has no MoE kernels on BI-V100).

        Weight layout read from ``self.experts``:
            w13_weight: (num_experts, 2*inter_per_partition, hidden)  [TP-sharded]
            w2_weight:  (num_experts, hidden, inter_per_partition)    [TP-sharded]

        Args:
            hidden_states: (T, hidden) token activations.
            router_logits: (T, num_experts) raw router scores.

        Returns:
            (T, hidden) tensor in the dtype of ``hidden_states``.  The output
            is partial (pre-all-reduce), same contract as FusedMoE with
            reduce_results=False.
        """
        # Routing: softmax (in fp32) → top-k → renormalise over the k picks.
        routing_weights = torch.softmax(router_logits.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(
            routing_weights, self.top_k, dim=-1)  # both (T, top_k)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(hidden_states.dtype)

        out = torch.zeros_like(hidden_states)
        w13 = self.experts.w13_weight  # (E, 2*I, H)
        w2 = self.experts.w2_weight    # (E, H, I)

        # Visit only experts that actually received tokens.  torch.unique
        # returns the ids sorted ascending, so the accumulation order is the
        # same as scanning range(num_experts) and skipping empty experts,
        # but an expert without tokens no longer costs a mask + nonzero().
        for eid in torch.unique(topk_ids).tolist():
            # Rows (tokens) and top-k slots in which this expert was chosen.
            tok_ids, topk_pos = (topk_ids == eid).nonzero(as_tuple=True)
            tokens = hidden_states[tok_ids]                   # (n, H)

            # gate + up projection (ColumnParallel shard)
            gate, up = F.linear(tokens, w13[eid]).chunk(2, dim=-1)
            act = F.silu(gate) * up                           # (n, I)

            # down projection (RowParallel shard) — result is partial.
            # F.linear(x, W) = x @ W.T; w2[eid] is (H, I) → output (n, H).
            expert_out = F.linear(act, w2[eid])

            weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
            out.index_add_(0, tok_ids,
                           (expert_out * weights).to(out.dtype))

        return out  # partial; the all-reduce is done in forward()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return routed-expert output plus gated shared-expert output.

        Both partial results are summed first and then all-reduced once when
        ``self.experts.tp_size > 1``.
        """
        router_logits, _ = self.gate(hidden_states)
        routed_out = self._pure_pytorch_experts(hidden_states, router_logits)

        gate_up, _ = self.shared_expert_gate_up(hidden_states)
        shared_out = self.act_fn(gate_up)
        shared_out, _ = self.shared_expert_down(shared_out)

        # Scalar sigmoid gate (Qwen2-MoE / Qwen3.5-MoE style)
        gate_score, _ = self.shared_expert_gate(hidden_states)  # (T, 1)
        shared_out = shared_out * torch.sigmoid(gate_score)

        out = routed_out + shared_out
        if self.experts.tp_size > 1:
            out = tensor_model_parallel_all_reduce(out)
        return out
# ---------------------------------------------------------------------------
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
# ---------------------------------------------------------------------------
class Qwen3_5DecoderLayer(nn.Module):
def __init__(
self,
@@ -623,6 +790,7 @@ class Qwen3_5DecoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.layer_type = layer_type
self.input_layernorm = GemmaRMSNorm(text_cfg.hidden_size,
eps=text_cfg.rms_norm_eps)
@@ -640,12 +808,15 @@ class Qwen3_5DecoderLayer(nn.Module):
prefix=f"layers.{layer_idx}.self_attn",
)
self.mlp = Qwen3_5MLP(
hidden_size=text_cfg.hidden_size,
intermediate_size=text_cfg.intermediate_size,
hidden_act=text_cfg.hidden_act,
quant_config=quant_config,
)
if getattr(text_cfg, 'model_type', '') == 'qwen3_5_moe_text':
self.mlp = Qwen3_5MoeSparseBlock(text_cfg, quant_config=quant_config)
else:
self.mlp = Qwen3_5MLP(
hidden_size=text_cfg.hidden_size,
intermediate_size=text_cfg.intermediate_size,
hidden_act=text_cfg.hidden_act,
quant_config=quant_config,
)
def forward(
self,
@@ -673,7 +844,9 @@ class Qwen3_5DecoderLayer(nn.Module):
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@@ -860,8 +1033,9 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
# With chunked prefill, intermediate chunks have seq_groups=None on all
# ranks; _apply_logits_processors is guarded against this in
# logits_processor.py (patched by patch_xformers_sdpa_seq.py).
return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
@@ -892,12 +1066,9 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
or name.startswith("model.mtp")):
continue
# Remap checkpoint prefix → module path
# Checkpoint: "model.language_model.{rest}" → our module: "model.{rest}"
# Checkpoint: "lm_head.weight" → our module: "lm_head.weight"
# Prefix remapping: checkpoint may wrap under language_model
if name.startswith("model.language_model."):
name = "model." + name[len("model.language_model."):]
# lm_head is already at top level — no change needed
# Skip positional embedding caches
if "rotary_emb.inv_freq" in name:
@@ -931,3 +1102,118 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# ---------------------------------------------------------------------------
# Qwen3.6-35B-A3B (Qwen3_5-MoE architecture)
# ---------------------------------------------------------------------------
class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLM):
    """Qwen3.6-35B-A3B: same hybrid-attention backbone as 27B, dense MLP
    replaced by Qwen3_5MoeSparseBlock (256 routed experts + shared expert).

    Only load_weights differs from the dense variant.
    """

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Checkpoint key format for this model (transformers Qwen3_5MoeExperts):
            mlp.experts.gate_up_proj  (num_experts, 2*intermediate, hidden)
            mlp.experts.down_proj     (num_experts, hidden, intermediate)
            mlp.gate.weight           (num_experts, hidden)   [router]
            mlp.shared_expert.{gate,up,down}_proj.weight      [shared MLP]

        Our FusedMoE stores:
            mlp.experts.w13_weight  (num_experts, 2*intermediate//tp, hidden)
            mlp.experts.w2_weight   (num_experts, hidden, intermediate//tp)

        Our shared expert stores:
            mlp.shared_expert_gate_up.weight  (merged gate+up)
            mlp.shared_expert_down.weight

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.

        Tensors whose (remapped) name has no matching parameter are skipped
        silently, as are the vision / MTP branches and rotary ``inv_freq``
        caches.
        """
        # Order matters: "gate_proj" is a substring of
        # "shared_expert.gate_proj", and the first matching entry wins, so
        # the shared-expert entries must come before the generic ones.
        stacked_params_mapping = [
            # (param_name, weight_name, shard_id)
            # shared expert
            ("shared_expert_gate_up", "shared_expert.gate_proj", 0),
            ("shared_expert_gate_up", "shared_expert.up_proj", 1),
            # remaining gate/up projections merged into gate_up_proj
            # (same mapping as 27B)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        skipped_prefixes = ("model.visual", "mtp.", "model.mtp")
        lm_prefix = "model.language_model."

        for name, loaded_weight in weights:
            # Skip vision and MTP branches
            if name.startswith(skipped_prefixes):
                continue

            # Prefix remapping: the checkpoint may wrap the text model under
            # language_model, i.e. "model.language_model.{rest}" is loaded
            # into "model.{rest}".  Other names are left unchanged.
            if name.startswith(lm_prefix):
                name = "model." + name[len(lm_prefix):]

            # Skip positional embedding caches
            if "rotary_emb.inv_freq" in name:
                continue

            if ".linear_attn.conv1d.weight" in name:
                name = name.replace(".linear_attn.conv1d.weight",
                                    ".linear_attn.conv1d_weight")

            # --- Fused routed-expert weights (all experts in one tensor) ---
            if "mlp.experts.gate_up_proj" in name:
                # loaded_weight: (num_experts, 2*intermediate, hidden);
                # first half of dim 1 is gate (w1), second half is up (w3).
                w13_name = name.replace("mlp.experts.gate_up_proj",
                                        "mlp.experts.w13_weight")
                if w13_name not in params_dict:
                    continue
                param = params_dict[w13_name]
                inter = loaded_weight.shape[1] // 2
                for eid in range(loaded_weight.shape[0]):
                    # Slice per expert instead of materialising contiguous
                    # copies of the whole gate / up halves: each per-expert
                    # row block is already a view of the checkpoint tensor,
                    # so no extra full-tensor copy is needed.
                    expert_w = loaded_weight[eid]
                    param.weight_loader(param, expert_w[:inter],
                                        "w1_weight", "w1", eid)
                    param.weight_loader(param, expert_w[inter:],
                                        "w3_weight", "w3", eid)
                continue

            if "mlp.experts.down_proj" in name:
                # loaded_weight: (num_experts, hidden, intermediate)
                w2_name = name.replace("mlp.experts.down_proj",
                                       "mlp.experts.w2_weight")
                if w2_name not in params_dict:
                    continue
                param = params_dict[w2_name]
                for eid in range(loaded_weight.shape[0]):
                    param.weight_loader(param, loaded_weight[eid],
                                        "w2_weight", "w2", eid)
                continue

            # --- Shared expert down_proj rename ---
            if "mlp.shared_expert.down_proj" in name:
                name = name.replace("mlp.shared_expert.down_proj",
                                    "mlp.shared_expert_down")
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                continue

            # --- Stacked / standard weights ---
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name not in params_dict:
                    break
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: load under the name as-is.
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)