fix incorrect MoE step to ensure decoding speed

This commit is contained in:
2026-06-12 11:44:50 +08:00
parent 629f878c28
commit 50e3a05fb0

View File

@@ -733,28 +733,53 @@ class Qwen3_5MoeSparseBlock(nn.Module):
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(hidden_states.dtype)
out = torch.zeros_like(hidden_states)
w13 = self.experts.w13_weight # (E, 2*I, H)
w2 = self.experts.w2_weight # (E, H, I)
for eid in range(self.num_experts):
# Tokens routed to this expert
mask = (topk_ids == eid) # (T, top_k) bool
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
if tok_ids.numel() == 0:
continue
T = hidden_states.shape[0]
if T == 1:
# Fast path: single token (decode).
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
# gate_up: 1 large GEMM (1,H) × (H, K*2*I) → (1, K*2*I)
# down: 1 einsum (K,I) × (K,H,I) → (K,H)
# Total: 3 kernel launches vs. the previous top_k*2 F.linear calls (16 when top_k = 8).
eids = topk_ids[0] # (K,)
ws = topk_weights[0].to(hidden_states.dtype) # (K,)
w13_sel = w13[eids] # (K, 2*I, H)
w2_sel = w2[eids] # (K, H, I)
tokens = hidden_states[tok_ids] # (n, H)
# gate + up projection (ColumnParallel shard)
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
gate, up = gate_up.chunk(2, dim=-1)
act = F.silu(gate) * up # (n, I)
# down projection (RowParallel shard) — result is partial
# F.linear(x, W) = x @ W.T; w2[eid]: (H, I) → x @ W.T = (n,H) ✓
expert_out = F.linear(act, w2[eid]) # (n, H)
H = hidden_states.shape[-1]
K2I = w13_sel.shape[1] # 2*I per expert (w13_sel is (K, 2*I, H)); the reshape below yields K*2*I rows
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
gate_up = F.linear(
hidden_states,
w13_sel.reshape(-1, H).contiguous(), # (K*2*I, H)
) # (1, K*2*I)
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
act = F.silu(gate) * up # (K, I)
# einsum: result[k,h] = sum_i act[k,i] * w2_sel[k,h,i]
expert_out = torch.einsum('ki,khi->kh', act, w2_sel) # (K, H)
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
hidden_states.dtype) # (1, H)
else:
# General path (prefill / multi-seq): loop over unique active experts.
# At most T*top_k unique experts, always <= num_experts.
out = torch.zeros_like(hidden_states)
unique_eids = topk_ids.view(-1).unique().tolist()
for eid in unique_eids:
eid = int(eid)
mask = (topk_ids == eid) # (T, top_k)
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
tokens = hidden_states[tok_ids] # (n, H)
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
gate, up = gate_up.chunk(2, dim=-1)
act = F.silu(gate) * up # (n, I)
expert_out = F.linear(act, w2[eid]) # (n, H)
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
return out # partial, all-reduce done in forward()