some op overhead optimization

This commit is contained in:
2026-06-19 11:19:39 +08:00
parent 47a4d9e72a
commit b5806731e0

View File

@@ -420,9 +420,6 @@ class GatedDeltaNet(nn.Module):
else: else:
# Decode: one token per sequence # Decode: one token per sequence
with open("/tmp/vllm_decode_debug.log", "a") as _f:
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
_f.flush()
num_seqs = hidden_states.shape[0] num_seqs = hidden_states.shape[0]
weight_2d = self.conv1d_weight.squeeze(1) weight_2d = self.conv1d_weight.squeeze(1)
@@ -452,17 +449,47 @@ class GatedDeltaNet(nn.Module):
q = q.repeat_interleave(self.head_expand_ratio, dim=2) q = q.repeat_interleave(self.head_expand_ratio, dim=2)
k = k.repeat_interleave(self.head_expand_ratio, dim=2) k = k.repeat_interleave(self.head_expand_ratio, dim=2)
core_out, last_state = _torch_recurrent_gated_delta_rule( # Inlined decode recurrent step (seq_len=1).
q, k, v, g, beta, # Replaces _torch_recurrent_gated_delta_rule to avoid 5 transpose+
initial_state=temporal_state, # contiguous+float32 copies, core_out allocation, and Python loop.
output_final_state=True, # Uses bmm/baddbmm_ to eliminate 3 large (B,H,k,v) intermediate tensors.
use_qk_l2norm_in_kernel=True, # temporal_state: (B, H_v, k_dim, v_dim) float32 — updated in-place.
orig_dtype = q.dtype
_scale = self.head_k_dim ** -0.5
q_t = _l2norm(q.squeeze(1)).float() * _scale # (B, H_v, k_dim)
k_t = _l2norm(k.squeeze(1)).float() # (B, H_v, k_dim)
v_t = v.squeeze(1).float() # (B, H_v, v_dim)
g_t = g.squeeze(1).float().exp_() # (B, H_v)
bt = beta.squeeze(1).float() # (B, H_v)
# Decay state in-place: (B, H_v, k_dim, v_dim) *= scalar per head
temporal_state.mul_(g_t[:, :, None, None])
# Reshape to batched-matmul layout: (B*H_v, k_dim, v_dim)
ts_flat = temporal_state.view(-1, self.head_k_dim, self.head_v_dim)
BH = ts_flat.shape[0]
# kv_mem = k_t @ temporal_state shape: (B*H_v, 1, k_dim) @ (B*H_v, k_dim, v_dim)
kv_mem = torch.bmm(
k_t.view(BH, 1, self.head_k_dim), ts_flat
).view(num_seqs, local_num_v, self.head_v_dim) # (B, H_v, v_dim)
delta = (v_t - kv_mem) * bt[:, :, None] # (B, H_v, v_dim)
# State update: temporal_state += outer(k_t, delta) fused, no intermediate
ts_flat.baddbmm_(
k_t.view(BH, self.head_k_dim, 1),
delta.view(BH, 1, self.head_v_dim),
) )
if last_state is not None:
temporal_state.copy_(last_state) # Output: core_out = q_t @ updated temporal_state
core_out = torch.bmm(
q_t.view(BH, 1, self.head_k_dim), ts_flat
).view(num_seqs, local_num_v, self.head_v_dim).to(orig_dtype)
# core_out: (B, H_v, v_dim) = (num_seqs, local_num_v, head_v_dim) already
z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim) z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
core_out = core_out.reshape(num_seqs, local_num_v, self.head_v_dim)
normed = self.norm( normed = self.norm(
core_out.reshape(-1, self.head_v_dim), core_out.reshape(-1, self.head_v_dim),
z.reshape(-1, self.head_v_dim)) z.reshape(-1, self.head_v_dim))
@@ -740,27 +767,26 @@ class Qwen3_5MoeSparseBlock(nn.Module):
if T == 1: if T == 1:
# Fast path: single token (decode). # Fast path: single token (decode).
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops. # Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
# gate_up: 1 large GEMM (1,H) × (H, K*2*I) → (1, K*2*I) # gate_up: 1 large GEMM (1,H) × (K*2*I,H)^T → (1, K*2*I)
# down: 1 einsum (K,I) × (K,H,I) → (K,H) # down: 1 bmm (K,H,I) @ (K,I,1) → (K,H)
# Total: 3 kernel launches vs previous 16 (top_k*2). # Total: 3 kernel launches vs previous 16 (top_k*2).
eids = topk_ids[0] # (K,) eids = topk_ids[0] # (K,)
ws = topk_weights[0].to(hidden_states.dtype) # (K,) ws = topk_weights[0].to(hidden_states.dtype) # (K,)
w13_sel = w13[eids] # (K, 2*I, H) w13_sel = w13[eids] # (K, 2*I, H)
w2_sel = w2[eids] # (K, H, I) w2_sel = w2[eids] # (K, H, I)
H = hidden_states.shape[-1] H = hidden_states.shape[-1]
K2I = w13_sel.shape[1] # K * (2*I) after reshape
gate_up = F.linear( gate_up = F.linear(
hidden_states, hidden_states,
w13_sel.reshape(-1, H).contiguous(), # (K*2*I, H) w13_sel.reshape(-1, H), # (K*2*I, H) — contiguous after indexing
) # (1, K*2*I) ) # (1, K*2*I)
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I) gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
act = F.silu(gate) * up # (K, I) act = F.silu(gate) * up # (K, I)
# einsum: result[k,h] = sum_i act[k,i] * w2_sel[k,h,i] # bmm: (K,H,I) @ (K,I,1) → (K,H,1) → (K,H)
expert_out = torch.einsum('ki,khi->kh', act, w2_sel) # (K, H) expert_out = torch.bmm(w2_sel, act.unsqueeze(-1)).squeeze(-1) # (K, H)
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to( out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
hidden_states.dtype) # (1, H) hidden_states.dtype) # (1, H)