diff --git a/qwen3_6_scripts/qwen3_5.py b/qwen3_6_scripts/qwen3_5.py index 60d3d62..03dde1c 100644 --- a/qwen3_6_scripts/qwen3_5.py +++ b/qwen3_6_scripts/qwen3_5.py @@ -733,28 +733,53 @@ class Qwen3_5MoeSparseBlock(nn.Module): topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(hidden_states.dtype) - out = torch.zeros_like(hidden_states) w13 = self.experts.w13_weight # (E, 2*I, H) w2 = self.experts.w2_weight # (E, H, I) - for eid in range(self.num_experts): - # Tokens routed to this expert - mask = (topk_ids == eid) # (T, top_k) bool - tok_ids, topk_pos = mask.nonzero(as_tuple=True) - if tok_ids.numel() == 0: - continue + T = hidden_states.shape[0] + if T == 1: + # Fast path: single token (decode). + # Batched GEMM: replace top_k separate F.linear calls with 2 fused ops. + # gate_up: 1 large GEMM (1,H) × (H, K*2*I) → (1, K*2*I) + # down: 1 einsum (K,I) × (K,H,I) → (K,H) + # Total: 3 kernel launches vs previous 16 (top_k*2). + eids = topk_ids[0] # (K,) + ws = topk_weights[0].to(hidden_states.dtype) # (K,) + w13_sel = w13[eids] # (K, 2*I, H) + w2_sel = w2[eids] # (K, H, I) - tokens = hidden_states[tok_ids] # (n, H) - # gate + up projection (ColumnParallel shard) - gate_up = F.linear(tokens, w13[eid]) # (n, 2*I) - gate, up = gate_up.chunk(2, dim=-1) - act = F.silu(gate) * up # (n, I) - # down projection (RowParallel shard) — result is partial - # F.linear(x, W) = x @ W.T; w2[eid]: (H, I) → x @ W.T = (n,H) ✓ - expert_out = F.linear(act, w2[eid]) # (n, H) + H = hidden_states.shape[-1] + K2I = w13_sel.shape[1] # K * (2*I) after reshape - weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1) - out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype)) + gate_up = F.linear( + hidden_states, + w13_sel.reshape(-1, H).contiguous(), # (K*2*I, H) + ) # (1, K*2*I) + gate_up = gate_up.view(self.top_k, -1) # (K, 2*I) + gate, up = gate_up.chunk(2, dim=-1) # (K, I) each + act = F.silu(gate) * up # (K, I) + + # einsum: result[k,h] = sum_i act[k,i] * w2_sel[k,h,i] + expert_out = torch.einsum('ki,khi->kh', act, w2_sel) # (K, H) + + out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to( + hidden_states.dtype) # (1, H) + else: + # General path (prefill / multi-seq): loop over unique active experts. + # At most T*top_k unique experts, always <= num_experts. + out = torch.zeros_like(hidden_states) + unique_eids = topk_ids.view(-1).unique().tolist() + for eid in unique_eids: + eid = int(eid) + mask = (topk_ids == eid) # (T, top_k) + tok_ids, topk_pos = mask.nonzero(as_tuple=True) + tokens = hidden_states[tok_ids] # (n, H) + gate_up = F.linear(tokens, w13[eid]) # (n, 2*I) + gate, up = gate_up.chunk(2, dim=-1) + act = F.silu(gate) * up # (n, I) + expert_out = F.linear(act, w2[eid]) # (n, H) + weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1) + out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype)) return out # partial, all-reduce done in forward()