initial commit for qwen3.6-moe adaptation
This commit is contained in:
@@ -11,12 +11,15 @@ from torch import nn
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
@@ -495,24 +498,52 @@ class Qwen3_5FullAttention(nn.Module):
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.local_num_heads = self.num_heads // tp_size
|
||||
self.local_num_kv_heads = max(1, self.num_kv_heads // tp_size)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# When num_kv_heads < tp_size we cannot shard KV further (would give
|
||||
# fractional heads per rank). Use ReplicatedLinear so every rank holds
|
||||
# all KV heads; local_num_kv_heads equals the full count.
|
||||
# When num_kv_heads >= tp_size standard ColumnParallel sharding applies.
|
||||
if tp_size > self.num_kv_heads:
|
||||
# GQA-aware TP sharding: ixformer kernel only supports num_kv_heads=1
|
||||
# per rank. With num_kv_heads=2 < tp_size=4 we cannot shard KV
|
||||
# evenly, but we CAN assign each rank the ONE KV head that serves
|
||||
# its Q heads:
|
||||
# q_per_kv = num_heads // num_kv_heads (e.g. 16//2 = 8)
|
||||
# Rank r uses KV head r * local_num_heads // q_per_kv
|
||||
# e.g. ranks 0,1 → KV head 0; ranks 2,3 → KV head 1.
|
||||
# We replicate all KV heads to every rank and select in forward().
|
||||
self.proj_kv_heads = self.num_kv_heads # heads available from projection
|
||||
self.local_num_kv_heads = 1 # heads after rank-local selection
|
||||
self.q_per_kv_global = self.num_heads // self.num_kv_heads
|
||||
self.k_proj = ReplicatedLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config)
|
||||
self.v_proj = ReplicatedLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config)
|
||||
else:
|
||||
# Standard sharding: each rank gets num_kv_heads // tp_size heads.
|
||||
self.local_num_kv_heads = self.num_kv_heads // tp_size
|
||||
self.proj_kv_heads = self.local_num_kv_heads # already sharded
|
||||
self.q_per_kv_global = None
|
||||
self.k_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.k_proj")
|
||||
self.v_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.v_proj")
|
||||
|
||||
self.local_q_dim = self.local_num_heads * self.head_dim
|
||||
self.local_kv_dim = self.local_num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# q_proj includes gate: output = num_heads * head_dim * 2
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_heads * self.head_dim * 2,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
self.k_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.k_proj")
|
||||
self.v_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.v_proj")
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim, self.hidden_size,
|
||||
bias=False, quant_config=quant_config,
|
||||
@@ -559,18 +590,34 @@ class Qwen3_5FullAttention(nn.Module):
|
||||
q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
|
||||
gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)
|
||||
|
||||
k, _ = self.k_proj(hidden_states) # (total, local_kv_dim)
|
||||
k, _ = self.k_proj(hidden_states) # (total, proj_kv_heads * head_dim)
|
||||
v, _ = self.v_proj(hidden_states)
|
||||
|
||||
# Per-head RMSNorm
|
||||
# q_norm on local Q heads
|
||||
q = self.q_norm.forward_cuda(
|
||||
q.view(total_tokens, self.local_num_heads, self.head_dim)
|
||||
.contiguous()).view(total_tokens, -1)
|
||||
|
||||
# GQA-aware TP: select rank-local KV head BEFORE k_norm and rope so
|
||||
# that ixformer kernels always see num_kv_heads=1 (same as 27B path).
|
||||
# Doing k_norm/rope on 2 KV heads (proj_kv_heads=2) triggers ixformer
|
||||
# paths that can produce NaN; restricting to 1 head avoids the issue.
|
||||
if self.q_per_kv_global is not None:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
kv_idx = (tp_rank * self.local_num_heads) // self.q_per_kv_global
|
||||
k = (k.view(total_tokens, self.proj_kv_heads, self.head_dim)
|
||||
[:, kv_idx, :].contiguous()) # (T, head_dim) — 1 head
|
||||
v = (v.view(total_tokens, self.proj_kv_heads, self.head_dim)
|
||||
[:, kv_idx, :].contiguous()) # (T, head_dim) — 1 head
|
||||
|
||||
# k_norm on the (now always 1) rank-local KV head
|
||||
k = self.k_norm.forward_cuda(
|
||||
k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
|
||||
.contiguous()).view(total_tokens, -1)
|
||||
|
||||
# rope: q=(T, local_num_heads*head_dim), k=(T, 1*head_dim) — mirrors 27B
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
|
||||
# Multiply by sigmoid gate before output projection
|
||||
@@ -609,10 +656,130 @@ class Qwen3_5MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MoE sparse block (Qwen3.5-MoE / Qwen3.6-35B-A3B)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5MoeSparseBlock(nn.Module):
|
||||
"""Replaces Qwen3_5MLP for qwen3_5_moe_text layers.
|
||||
|
||||
FusedMoE is used ONLY for weight storage and loading (create_weights /
|
||||
weight_loader are pure PyTorch). Its forward kernel is bypassed because
|
||||
ixformer on BI-V100 lacks vllm_moe_topk_softmax / vllm_invoke_fused_moe_kernel.
|
||||
Routing and expert computation use a pure-PyTorch loop instead.
|
||||
|
||||
Shared expert uses RowParallelLinear(reduce_results=False) so both paths
|
||||
produce partial (pre-all-reduce) outputs that are combined before a single
|
||||
all-reduce.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_cfg,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
hidden_size = text_cfg.hidden_size
|
||||
self.num_experts = text_cfg.num_experts
|
||||
self.top_k = text_cfg.num_experts_per_tok
|
||||
|
||||
# Router: replicated (small: num_experts outputs)
|
||||
self.gate = ReplicatedLinear(hidden_size, text_cfg.num_experts,
|
||||
bias=False, quant_config=quant_config)
|
||||
|
||||
# FusedMoE: only used for weight storage + weight_loader.
|
||||
# Forward is bypassed — see _pure_pytorch_experts().
|
||||
self.experts = FusedMoE(
|
||||
num_experts=text_cfg.num_experts,
|
||||
top_k=text_cfg.num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=text_cfg.moe_intermediate_size,
|
||||
reduce_results=False, # we do the all-reduce ourselves below
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Shared expert: defer all-reduce to combine with routed output first
|
||||
shared_size = text_cfg.shared_expert_intermediate_size
|
||||
self.shared_expert_gate_up = MergedColumnParallelLinear(
|
||||
hidden_size, [shared_size] * 2, bias=False,
|
||||
quant_config=quant_config)
|
||||
self.shared_expert_down = RowParallelLinear(
|
||||
shared_size, hidden_size, bias=False, reduce_results=False,
|
||||
quant_config=quant_config)
|
||||
self.act_fn = SiluAndMul()
|
||||
# Scalar sigmoid gate on shared expert output (same as Qwen2-MoE / Qwen3.5-MoE):
|
||||
# shared_out *= sigmoid(shared_expert_gate(hidden_states))
|
||||
# Without this, shared expert is always fully active → wrong logits.
|
||||
self.shared_expert_gate = ReplicatedLinear(
|
||||
hidden_size, 1, bias=False, quant_config=quant_config)
|
||||
|
||||
def _pure_pytorch_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Pure-PyTorch MoE (ixformer has no MoE kernels on BI-V100).
|
||||
|
||||
w13_weight: (num_experts, 2*inter_per_partition, hidden) [TP-sharded]
|
||||
w2_weight: (num_experts, hidden, inter_per_partition) [TP-sharded]
|
||||
Output is partial (pre-all-reduce), same contract as FusedMoE
|
||||
with reduce_results=False.
|
||||
"""
|
||||
# Routing: softmax → topk → renormalise
|
||||
routing_weights = torch.softmax(router_logits.float(), dim=-1)
|
||||
topk_weights, topk_ids = torch.topk(
|
||||
routing_weights, self.top_k, dim=-1) # (T, top_k)
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
|
||||
out = torch.zeros_like(hidden_states)
|
||||
w13 = self.experts.w13_weight # (E, 2*I, H)
|
||||
w2 = self.experts.w2_weight # (E, H, I)
|
||||
|
||||
for eid in range(self.num_experts):
|
||||
# Tokens routed to this expert
|
||||
mask = (topk_ids == eid) # (T, top_k) bool
|
||||
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
|
||||
if tok_ids.numel() == 0:
|
||||
continue
|
||||
|
||||
tokens = hidden_states[tok_ids] # (n, H)
|
||||
# gate + up projection (ColumnParallel shard)
|
||||
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
|
||||
gate, up = gate_up.chunk(2, dim=-1)
|
||||
act = F.silu(gate) * up # (n, I)
|
||||
# down projection (RowParallel shard) — result is partial
|
||||
# F.linear(x, W) = x @ W.T; w2[eid]: (H, I) → x @ W.T = (n,H) ✓
|
||||
expert_out = F.linear(act, w2[eid]) # (n, H)
|
||||
|
||||
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
|
||||
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
|
||||
|
||||
return out # partial, all-reduce done in forward()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
routed_out = self._pure_pytorch_experts(hidden_states, router_logits)
|
||||
|
||||
gate_up, _ = self.shared_expert_gate_up(hidden_states)
|
||||
shared_out = self.act_fn(gate_up)
|
||||
shared_out, _ = self.shared_expert_down(shared_out)
|
||||
# Scalar sigmoid gate (Qwen2-MoE / Qwen3.5-MoE style)
|
||||
gate_score, _ = self.shared_expert_gate(hidden_states) # (T, 1)
|
||||
shared_out = shared_out * torch.sigmoid(gate_score)
|
||||
|
||||
out = routed_out + shared_out
|
||||
if self.experts.tp_size > 1:
|
||||
out = tensor_model_parallel_all_reduce(out)
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Qwen3_5DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -623,6 +790,7 @@ class Qwen3_5DecoderLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.layer_type = layer_type
|
||||
self.input_layernorm = GemmaRMSNorm(text_cfg.hidden_size,
|
||||
eps=text_cfg.rms_norm_eps)
|
||||
@@ -640,12 +808,15 @@ class Qwen3_5DecoderLayer(nn.Module):
|
||||
prefix=f"layers.{layer_idx}.self_attn",
|
||||
)
|
||||
|
||||
self.mlp = Qwen3_5MLP(
|
||||
hidden_size=text_cfg.hidden_size,
|
||||
intermediate_size=text_cfg.intermediate_size,
|
||||
hidden_act=text_cfg.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if getattr(text_cfg, 'model_type', '') == 'qwen3_5_moe_text':
|
||||
self.mlp = Qwen3_5MoeSparseBlock(text_cfg, quant_config=quant_config)
|
||||
else:
|
||||
self.mlp = Qwen3_5MLP(
|
||||
hidden_size=text_cfg.hidden_size,
|
||||
intermediate_size=text_cfg.intermediate_size,
|
||||
hidden_act=text_cfg.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -673,7 +844,9 @@ class Qwen3_5DecoderLayer(nn.Module):
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@@ -860,8 +1033,9 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
# With chunked prefill, intermediate chunks have seq_groups=None on all
|
||||
# ranks; _apply_logits_processors is guarded against this in
|
||||
# logits_processor.py (patched by patch_xformers_sdpa_seq.py).
|
||||
return self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@@ -892,12 +1066,9 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
or name.startswith("model.mtp")):
|
||||
continue
|
||||
|
||||
# Remap checkpoint prefix → module path
|
||||
# Checkpoint: "model.language_model.{rest}" → our module: "model.{rest}"
|
||||
# Checkpoint: "lm_head.weight" → our module: "lm_head.weight"
|
||||
# Prefix remapping: checkpoint may wrap under language_model
|
||||
if name.startswith("model.language_model."):
|
||||
name = "model." + name[len("model.language_model."):]
|
||||
# lm_head is already at top level — no change needed
|
||||
|
||||
# Skip positional embedding caches
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
@@ -931,3 +1102,118 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Qwen3.6-35B-A3B (Qwen3_5-MoE architecture)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLM):
|
||||
"""Qwen3.6-35B-A3B: same hybrid-attention backbone as 27B, dense MLP
|
||||
replaced by Qwen3_5MoeSparseBlock (256 routed experts + shared expert).
|
||||
Only load_weights differs from the dense variant.
|
||||
"""
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# Checkpoint key format for this model (transformers Qwen3_5MoeExperts):
|
||||
# mlp.experts.gate_up_proj shape (num_experts, 2*intermediate, hidden)
|
||||
# mlp.experts.down_proj shape (num_experts, hidden, intermediate)
|
||||
# mlp.gate.weight shape (num_experts, hidden) [router]
|
||||
# mlp.shared_expert.{gate,up,down}_proj.weight [shared MLP]
|
||||
# Our FusedMoE stores:
|
||||
# mlp.experts.w13_weight shape (num_experts, 2*intermediate//tp, hidden)
|
||||
# mlp.experts.w2_weight shape (num_experts, hidden, intermediate//tp)
|
||||
# Our shared expert stores:
|
||||
# mlp.shared_expert_gate_up.weight (merged gate+up)
|
||||
# mlp.shared_expert_down.weight
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, weight_name, shard_id)
|
||||
# shared expert
|
||||
("shared_expert_gate_up", "shared_expert.gate_proj", 0),
|
||||
("shared_expert_gate_up", "shared_expert.up_proj", 1),
|
||||
# linear_attention dense proj (same as 27B)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# Skip vision and MTP branches
|
||||
if (name.startswith("model.visual")
|
||||
or name.startswith("mtp.")
|
||||
or name.startswith("model.mtp")):
|
||||
continue
|
||||
|
||||
# Prefix remapping for VL checkpoint (Qwen3_5MoeForConditionalGeneration):
|
||||
# model.language_model.model.{layers,embed_tokens,norm} -> model.{...}
|
||||
# model.language_model.lm_head -> lm_head
|
||||
# Prefix remapping: checkpoint may wrap under language_model
|
||||
if name.startswith("model.language_model."):
|
||||
name = "model." + name[len("model.language_model."):]
|
||||
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if ".linear_attn.conv1d.weight" in name:
|
||||
name = name.replace(".linear_attn.conv1d.weight",
|
||||
".linear_attn.conv1d_weight")
|
||||
|
||||
# --- Fused routed-expert weights (all experts in one tensor) ---
|
||||
|
||||
if "mlp.experts.gate_up_proj" in name:
|
||||
# loaded_weight: (num_experts, 2*intermediate, hidden)
|
||||
w13_name = name.replace("mlp.experts.gate_up_proj",
|
||||
"mlp.experts.w13_weight")
|
||||
if w13_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[w13_name]
|
||||
n_exp = loaded_weight.shape[0]
|
||||
inter = loaded_weight.shape[1] // 2
|
||||
gate_w = loaded_weight[:, :inter, :].contiguous()
|
||||
up_w = loaded_weight[:, inter:, :].contiguous()
|
||||
for eid in range(n_exp):
|
||||
param.weight_loader(param, gate_w[eid], "w1_weight", "w1", eid)
|
||||
param.weight_loader(param, up_w[eid], "w3_weight", "w3", eid)
|
||||
continue
|
||||
|
||||
if "mlp.experts.down_proj" in name:
|
||||
# loaded_weight: (num_experts, hidden, intermediate)
|
||||
w2_name = name.replace("mlp.experts.down_proj",
|
||||
"mlp.experts.w2_weight")
|
||||
if w2_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[w2_name]
|
||||
n_exp = loaded_weight.shape[0]
|
||||
for eid in range(n_exp):
|
||||
param.weight_loader(param, loaded_weight[eid], "w2_weight", "w2", eid)
|
||||
continue
|
||||
|
||||
# --- Shared expert down_proj rename ---
|
||||
if "mlp.shared_expert.down_proj" in name:
|
||||
name = name.replace("mlp.shared_expert.down_proj",
|
||||
"mlp.shared_expert_down")
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
|
||||
# --- Stacked / standard weights ---
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if name not in params_dict:
|
||||
break
|
||||
param = params_dict[name]
|
||||
param.weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user