diff --git a/qwen3_6_scripts/qwen3_5.py b/qwen3_6_scripts/qwen3_5.py index 03dde1c..617416b 100644 --- a/qwen3_6_scripts/qwen3_5.py +++ b/qwen3_6_scripts/qwen3_5.py @@ -420,9 +420,6 @@ class GatedDeltaNet(nn.Module): else: # Decode: one token per sequence - with open("/tmp/vllm_decode_debug.log", "a") as _f: - _f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n") - _f.flush() num_seqs = hidden_states.shape[0] weight_2d = self.conv1d_weight.squeeze(1) @@ -452,17 +449,47 @@ class GatedDeltaNet(nn.Module): q = q.repeat_interleave(self.head_expand_ratio, dim=2) k = k.repeat_interleave(self.head_expand_ratio, dim=2) - core_out, last_state = _torch_recurrent_gated_delta_rule( - q, k, v, g, beta, - initial_state=temporal_state, - output_final_state=True, - use_qk_l2norm_in_kernel=True, + # Inlined decode recurrent step (seq_len=1). + # Replaces _torch_recurrent_gated_delta_rule to avoid 5 transpose+ + # contiguous+float32 copies, core_out allocation, and Python loop. + # Uses bmm/baddbmm_ to eliminate 3 large (B,H,k,v) intermediate tensors. + # temporal_state: (B, H_v, k_dim, v_dim) float32 — updated in-place. + orig_dtype = q.dtype + _scale = self.head_k_dim ** -0.5 + + q_t = _l2norm(q.squeeze(1)).float() * _scale # (B, H_v, k_dim) + k_t = _l2norm(k.squeeze(1)).float() # (B, H_v, k_dim) + v_t = v.squeeze(1).float() # (B, H_v, v_dim) + g_t = g.squeeze(1).float().exp_() # (B, H_v) + bt = beta.squeeze(1).float() # (B, H_v) + + # Decay state in-place: (B, H_v, k_dim, v_dim) *= scalar per head + temporal_state.mul_(g_t[:, :, None, None]) + + # Reshape to batched-matmul layout: (B*H_v, k_dim, v_dim) + ts_flat = temporal_state.view(-1, self.head_k_dim, self.head_v_dim) + BH = ts_flat.shape[0] + + # kv_mem = k_t @ temporal_state shape: (B*H_v, 1, k_dim) @ (B*H_v, k_dim, v_dim) + kv_mem = torch.bmm( + k_t.view(BH, 1, self.head_k_dim), ts_flat + ).view(num_seqs, local_num_v, self.head_v_dim) # (B, H_v, v_dim) + + delta = (v_t - kv_mem) * bt[:, :, None] # (B, H_v, v_dim) + + # State update: temporal_state += outer(k_t, delta) fused, no intermediate + ts_flat.baddbmm_( + k_t.view(BH, self.head_k_dim, 1), + delta.view(BH, 1, self.head_v_dim), ) - if last_state is not None: - temporal_state.copy_(last_state) + + # Output: core_out = q_t @ updated temporal_state + core_out = torch.bmm( + q_t.view(BH, 1, self.head_k_dim), ts_flat + ).view(num_seqs, local_num_v, self.head_v_dim).to(orig_dtype) + # core_out: (B, H_v, v_dim) = (num_seqs, local_num_v, head_v_dim) already z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim) - core_out = core_out.reshape(num_seqs, local_num_v, self.head_v_dim) normed = self.norm( core_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim)) @@ -740,27 +767,26 @@ class Qwen3_5MoeSparseBlock(nn.Module): if T == 1: # Fast path: single token (decode). # Batched GEMM: replace top_k separate F.linear calls with 2 fused ops. - # gate_up: 1 large GEMM (1,H) × (H, K*2*I) → (1, K*2*I) - # down: 1 einsum (K,I) × (K,H,I) → (K,H) + # gate_up: 1 large GEMM (1,H) × (K*2*I,H)^T → (1, K*2*I) + # down: 1 bmm (K,H,I) @ (K,I,1) → (K,H) # Total: 3 kernel launches vs previous 16 (top_k*2). eids = topk_ids[0] # (K,) ws = topk_weights[0].to(hidden_states.dtype) # (K,) w13_sel = w13[eids] # (K, 2*I, H) w2_sel = w2[eids] # (K, H, I) - H = hidden_states.shape[-1] - K2I = w13_sel.shape[1] # K * (2*I) after reshape + H = hidden_states.shape[-1] gate_up = F.linear( hidden_states, - w13_sel.reshape(-1, H).contiguous(), # (K*2*I, H) + w13_sel.reshape(-1, H), # (K*2*I, H) — contiguous after indexing ) # (1, K*2*I) gate_up = gate_up.view(self.top_k, -1) # (K, 2*I) gate, up = gate_up.chunk(2, dim=-1) # (K, I) each act = F.silu(gate) * up # (K, I) - # einsum: result[k,h] = sum_i act[k,i] * w2_sel[k,h,i] - expert_out = torch.einsum('ki,khi->kh', act, w2_sel) # (K, H) + # bmm: (K,H,I) @ (K,I,1) → (K,H,1) → (K,H) + expert_out = torch.bmm(w2_sel, act.unsqueeze(-1)).squeeze(-1) # (K, H) out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to( hidden_states.dtype) # (1, H)