Reduce per-op overhead in the decode paths
This commit is contained in:
@@ -420,9 +420,6 @@ class GatedDeltaNet(nn.Module):
|
||||
|
||||
else:
|
||||
# Decode: one token per sequence
|
||||
with open("/tmp/vllm_decode_debug.log", "a") as _f:
|
||||
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
|
||||
_f.flush()
|
||||
num_seqs = hidden_states.shape[0]
|
||||
weight_2d = self.conv1d_weight.squeeze(1)
|
||||
|
||||
@@ -452,17 +449,47 @@ class GatedDeltaNet(nn.Module):
|
||||
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||
|
||||
core_out, last_state = _torch_recurrent_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=temporal_state,
|
||||
output_final_state=True,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
# Inlined decode recurrent step (seq_len=1).
|
||||
# Replaces _torch_recurrent_gated_delta_rule to avoid 5 transpose+
|
||||
# contiguous+float32 copies, core_out allocation, and Python loop.
|
||||
# Uses bmm/baddbmm_ to eliminate 3 large (B,H,k,v) intermediate tensors.
|
||||
# temporal_state: (B, H_v, k_dim, v_dim) float32 — updated in-place.
|
||||
orig_dtype = q.dtype
|
||||
_scale = self.head_k_dim ** -0.5
|
||||
|
||||
q_t = _l2norm(q.squeeze(1)).float() * _scale # (B, H_v, k_dim)
|
||||
k_t = _l2norm(k.squeeze(1)).float() # (B, H_v, k_dim)
|
||||
v_t = v.squeeze(1).float() # (B, H_v, v_dim)
|
||||
g_t = g.squeeze(1).float().exp_() # (B, H_v)
|
||||
bt = beta.squeeze(1).float() # (B, H_v)
|
||||
|
||||
# Decay state in-place: (B, H_v, k_dim, v_dim) *= scalar per head
|
||||
temporal_state.mul_(g_t[:, :, None, None])
|
||||
|
||||
# Reshape to batched-matmul layout: (B*H_v, k_dim, v_dim)
|
||||
ts_flat = temporal_state.view(-1, self.head_k_dim, self.head_v_dim)
|
||||
BH = ts_flat.shape[0]
|
||||
|
||||
# kv_mem = k_t @ temporal_state shape: (B*H_v, 1, k_dim) @ (B*H_v, k_dim, v_dim)
|
||||
kv_mem = torch.bmm(
|
||||
k_t.view(BH, 1, self.head_k_dim), ts_flat
|
||||
).view(num_seqs, local_num_v, self.head_v_dim) # (B, H_v, v_dim)
|
||||
|
||||
delta = (v_t - kv_mem) * bt[:, :, None] # (B, H_v, v_dim)
|
||||
|
||||
# State update: temporal_state += outer(k_t, delta) fused, no intermediate
|
||||
ts_flat.baddbmm_(
|
||||
k_t.view(BH, self.head_k_dim, 1),
|
||||
delta.view(BH, 1, self.head_v_dim),
|
||||
)
|
||||
if last_state is not None:
|
||||
temporal_state.copy_(last_state)
|
||||
|
||||
# Output: core_out = q_t @ updated temporal_state
|
||||
core_out = torch.bmm(
|
||||
q_t.view(BH, 1, self.head_k_dim), ts_flat
|
||||
).view(num_seqs, local_num_v, self.head_v_dim).to(orig_dtype)
|
||||
# core_out: (B, H_v, v_dim) = (num_seqs, local_num_v, head_v_dim) already
|
||||
|
||||
z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
|
||||
core_out = core_out.reshape(num_seqs, local_num_v, self.head_v_dim)
|
||||
normed = self.norm(
|
||||
core_out.reshape(-1, self.head_v_dim),
|
||||
z.reshape(-1, self.head_v_dim))
|
||||
@@ -740,27 +767,26 @@ class Qwen3_5MoeSparseBlock(nn.Module):
|
||||
if T == 1:
|
||||
# Fast path: single token (decode).
|
||||
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
|
||||
# gate_up: 1 large GEMM (1,H) × (H, K*2*I) → (1, K*2*I)
|
||||
# down: 1 einsum (K,I) × (K,H,I) → (K,H)
|
||||
# gate_up: 1 large GEMM (1,H) × (K*2*I,H)^T → (1, K*2*I)
|
||||
# down: 1 bmm (K,H,I) @ (K,I,1) → (K,H)
|
||||
# Total: 3 kernel launches vs previous 16 (top_k*2).
|
||||
eids = topk_ids[0] # (K,)
|
||||
ws = topk_weights[0].to(hidden_states.dtype) # (K,)
|
||||
w13_sel = w13[eids] # (K, 2*I, H)
|
||||
w2_sel = w2[eids] # (K, H, I)
|
||||
|
||||
H = hidden_states.shape[-1]
|
||||
K2I = w13_sel.shape[1] # K * (2*I) after reshape
|
||||
H = hidden_states.shape[-1]
|
||||
|
||||
gate_up = F.linear(
|
||||
hidden_states,
|
||||
w13_sel.reshape(-1, H).contiguous(), # (K*2*I, H)
|
||||
w13_sel.reshape(-1, H), # (K*2*I, H) — contiguous after indexing
|
||||
) # (1, K*2*I)
|
||||
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
|
||||
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
|
||||
act = F.silu(gate) * up # (K, I)
|
||||
|
||||
# einsum: result[k,h] = sum_i act[k,i] * w2_sel[k,h,i]
|
||||
expert_out = torch.einsum('ki,khi->kh', act, w2_sel) # (K, H)
|
||||
# bmm: (K,H,I) @ (K,I,1) → (K,H,1) → (K,H)
|
||||
expert_out = torch.bmm(w2_sel, act.unsqueeze(-1)).squeeze(-1) # (K, H)
|
||||
|
||||
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
|
||||
hidden_states.dtype) # (1, H)
|
||||
|
||||
Reference in New Issue
Block a user