fix incorrect MoE step to ensure decoding speed
This commit is contained in:
@@ -733,28 +733,53 @@ class Qwen3_5MoeSparseBlock(nn.Module):
|
|||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||||
|
|
||||||
out = torch.zeros_like(hidden_states)
|
|
||||||
w13 = self.experts.w13_weight # (E, 2*I, H)
|
w13 = self.experts.w13_weight # (E, 2*I, H)
|
||||||
w2 = self.experts.w2_weight # (E, H, I)
|
w2 = self.experts.w2_weight # (E, H, I)
|
||||||
|
|
||||||
for eid in range(self.num_experts):
|
T = hidden_states.shape[0]
|
||||||
# Tokens routed to this expert
|
if T == 1:
|
||||||
mask = (topk_ids == eid) # (T, top_k) bool
|
# Fast path: single token (decode).
|
||||||
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
|
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
|
||||||
if tok_ids.numel() == 0:
|
# gate_up: 1 large GEMM (1,H) × (H, K*2*I) → (1, K*2*I)
|
||||||
continue
|
# down: 1 einsum (K,I) × (K,H,I) → (K,H)
|
||||||
|
# Total: 3 kernel launches vs previous 16 (top_k*2).
|
||||||
|
eids = topk_ids[0] # (K,)
|
||||||
|
ws = topk_weights[0].to(hidden_states.dtype) # (K,)
|
||||||
|
w13_sel = w13[eids] # (K, 2*I, H)
|
||||||
|
w2_sel = w2[eids] # (K, H, I)
|
||||||
|
|
||||||
tokens = hidden_states[tok_ids] # (n, H)
|
H = hidden_states.shape[-1]
|
||||||
# gate + up projection (ColumnParallel shard)
|
K2I = w13_sel.shape[1] # K * (2*I) after reshape
|
||||||
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
|
|
||||||
gate, up = gate_up.chunk(2, dim=-1)
|
|
||||||
act = F.silu(gate) * up # (n, I)
|
|
||||||
# down projection (RowParallel shard) — result is partial
|
|
||||||
# F.linear(x, W) = x @ W.T; w2[eid]: (H, I) → x @ W.T = (n,H) ✓
|
|
||||||
expert_out = F.linear(act, w2[eid]) # (n, H)
|
|
||||||
|
|
||||||
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
|
gate_up = F.linear(
|
||||||
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
|
hidden_states,
|
||||||
|
w13_sel.reshape(-1, H).contiguous(), # (K*2*I, H)
|
||||||
|
) # (1, K*2*I)
|
||||||
|
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
|
||||||
|
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
|
||||||
|
act = F.silu(gate) * up # (K, I)
|
||||||
|
|
||||||
|
# einsum: result[k,h] = sum_i act[k,i] * w2_sel[k,h,i]
|
||||||
|
expert_out = torch.einsum('ki,khi->kh', act, w2_sel) # (K, H)
|
||||||
|
|
||||||
|
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
|
||||||
|
hidden_states.dtype) # (1, H)
|
||||||
|
else:
|
||||||
|
# General path (prefill / multi-seq): loop over unique active experts.
|
||||||
|
# At most T*top_k unique experts, always <= num_experts.
|
||||||
|
out = torch.zeros_like(hidden_states)
|
||||||
|
unique_eids = topk_ids.view(-1).unique().tolist()
|
||||||
|
for eid in unique_eids:
|
||||||
|
eid = int(eid)
|
||||||
|
mask = (topk_ids == eid) # (T, top_k)
|
||||||
|
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
|
||||||
|
tokens = hidden_states[tok_ids] # (n, H)
|
||||||
|
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
|
||||||
|
gate, up = gate_up.chunk(2, dim=-1)
|
||||||
|
act = F.silu(gate) * up # (n, I)
|
||||||
|
expert_out = F.linear(act, w2[eid]) # (n, H)
|
||||||
|
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
|
||||||
|
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
|
||||||
|
|
||||||
return out # partial, all-reduce done in forward()
|
return out # partial, all-reduce done in forward()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user