fix incorrect MoE step to ensure decoding speed
This commit is contained in:
@@ -733,28 +733,53 @@ class Qwen3_5MoeSparseBlock(nn.Module):
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
|
||||
out = torch.zeros_like(hidden_states)
|
||||
w13 = self.experts.w13_weight # (E, 2*I, H)
|
||||
w2 = self.experts.w2_weight # (E, H, I)
|
||||
|
||||
for eid in range(self.num_experts):
|
||||
# Tokens routed to this expert
|
||||
mask = (topk_ids == eid) # (T, top_k) bool
|
||||
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
|
||||
if tok_ids.numel() == 0:
|
||||
continue
|
||||
T = hidden_states.shape[0]
|
||||
if T == 1:
|
||||
# Fast path: single token (decode).
|
||||
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
|
||||
# gate_up: 1 large GEMM (1,H) × (H, K*2*I) → (1, K*2*I)
|
||||
# down: 1 einsum (K,I) × (K,H,I) → (K,H)
|
||||
# Total: 3 kernel launches vs previous 16 (top_k*2).
|
||||
eids = topk_ids[0] # (K,)
|
||||
ws = topk_weights[0].to(hidden_states.dtype) # (K,)
|
||||
w13_sel = w13[eids] # (K, 2*I, H)
|
||||
w2_sel = w2[eids] # (K, H, I)
|
||||
|
||||
tokens = hidden_states[tok_ids] # (n, H)
|
||||
# gate + up projection (ColumnParallel shard)
|
||||
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
|
||||
gate, up = gate_up.chunk(2, dim=-1)
|
||||
act = F.silu(gate) * up # (n, I)
|
||||
# down projection (RowParallel shard) — result is partial
|
||||
# F.linear(x, W) = x @ W.T; w2[eid]: (H, I) → x @ W.T = (n,H) ✓
|
||||
expert_out = F.linear(act, w2[eid]) # (n, H)
|
||||
H = hidden_states.shape[-1]
|
||||
K2I = w13_sel.shape[1] # K * (2*I) after reshape
|
||||
|
||||
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
|
||||
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
|
||||
gate_up = F.linear(
|
||||
hidden_states,
|
||||
w13_sel.reshape(-1, H).contiguous(), # (K*2*I, H)
|
||||
) # (1, K*2*I)
|
||||
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
|
||||
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
|
||||
act = F.silu(gate) * up # (K, I)
|
||||
|
||||
# einsum: result[k,h] = sum_i act[k,i] * w2_sel[k,h,i]
|
||||
expert_out = torch.einsum('ki,khi->kh', act, w2_sel) # (K, H)
|
||||
|
||||
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
|
||||
hidden_states.dtype) # (1, H)
|
||||
else:
|
||||
# General path (prefill / multi-seq): loop over unique active experts.
|
||||
# At most T*top_k unique experts, always <= num_experts.
|
||||
out = torch.zeros_like(hidden_states)
|
||||
unique_eids = topk_ids.view(-1).unique().tolist()
|
||||
for eid in unique_eids:
|
||||
eid = int(eid)
|
||||
mask = (topk_ids == eid) # (T, top_k)
|
||||
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
|
||||
tokens = hidden_states[tok_ids] # (n, H)
|
||||
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
|
||||
gate, up = gate_up.chunk(2, dim=-1)
|
||||
act = F.silu(gate) * up # (n, I)
|
||||
expert_out = F.linear(act, w2[eid]) # (n, H)
|
||||
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
|
||||
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
|
||||
|
||||
return out # partial, all-reduce done in forward()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user