Reduce per-op overhead in the GatedDeltaNet decode step and the Qwen3_5MoeSparseBlock single-token fast path
This commit is contained in:
@@ -420,9 +420,6 @@ class GatedDeltaNet(nn.Module):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# Decode: one token per sequence
|
# Decode: one token per sequence
|
||||||
with open("/tmp/vllm_decode_debug.log", "a") as _f:
|
|
||||||
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
|
|
||||||
_f.flush()
|
|
||||||
num_seqs = hidden_states.shape[0]
|
num_seqs = hidden_states.shape[0]
|
||||||
weight_2d = self.conv1d_weight.squeeze(1)
|
weight_2d = self.conv1d_weight.squeeze(1)
|
||||||
|
|
||||||
@@ -452,17 +449,47 @@ class GatedDeltaNet(nn.Module):
|
|||||||
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||||
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||||
|
|
||||||
core_out, last_state = _torch_recurrent_gated_delta_rule(
|
# Inlined decode recurrent step (seq_len=1).
|
||||||
q, k, v, g, beta,
|
# Replaces _torch_recurrent_gated_delta_rule to avoid 5 transpose+
|
||||||
initial_state=temporal_state,
|
# contiguous+float32 copies, core_out allocation, and Python loop.
|
||||||
output_final_state=True,
|
# Uses bmm/baddbmm_ to eliminate 3 large (B,H,k,v) intermediate tensors.
|
||||||
use_qk_l2norm_in_kernel=True,
|
# temporal_state: (B, H_v, k_dim, v_dim) float32 — updated in-place.
|
||||||
|
orig_dtype = q.dtype
|
||||||
|
_scale = self.head_k_dim ** -0.5
|
||||||
|
|
||||||
|
q_t = _l2norm(q.squeeze(1)).float() * _scale # (B, H_v, k_dim)
|
||||||
|
k_t = _l2norm(k.squeeze(1)).float() # (B, H_v, k_dim)
|
||||||
|
v_t = v.squeeze(1).float() # (B, H_v, v_dim)
|
||||||
|
g_t = g.squeeze(1).float().exp_() # (B, H_v)
|
||||||
|
bt = beta.squeeze(1).float() # (B, H_v)
|
||||||
|
|
||||||
|
# Decay state in-place: (B, H_v, k_dim, v_dim) *= scalar per head
|
||||||
|
temporal_state.mul_(g_t[:, :, None, None])
|
||||||
|
|
||||||
|
# Reshape to batched-matmul layout: (B*H_v, k_dim, v_dim)
|
||||||
|
ts_flat = temporal_state.view(-1, self.head_k_dim, self.head_v_dim)
|
||||||
|
BH = ts_flat.shape[0]
|
||||||
|
|
||||||
|
# kv_mem = k_t @ temporal_state shape: (B*H_v, 1, k_dim) @ (B*H_v, k_dim, v_dim)
|
||||||
|
kv_mem = torch.bmm(
|
||||||
|
k_t.view(BH, 1, self.head_k_dim), ts_flat
|
||||||
|
).view(num_seqs, local_num_v, self.head_v_dim) # (B, H_v, v_dim)
|
||||||
|
|
||||||
|
delta = (v_t - kv_mem) * bt[:, :, None] # (B, H_v, v_dim)
|
||||||
|
|
||||||
|
# State update: temporal_state += outer(k_t, delta) fused, no intermediate
|
||||||
|
ts_flat.baddbmm_(
|
||||||
|
k_t.view(BH, self.head_k_dim, 1),
|
||||||
|
delta.view(BH, 1, self.head_v_dim),
|
||||||
)
|
)
|
||||||
if last_state is not None:
|
|
||||||
temporal_state.copy_(last_state)
|
# Output: core_out = q_t @ updated temporal_state
|
||||||
|
core_out = torch.bmm(
|
||||||
|
q_t.view(BH, 1, self.head_k_dim), ts_flat
|
||||||
|
).view(num_seqs, local_num_v, self.head_v_dim).to(orig_dtype)
|
||||||
|
# core_out: (B, H_v, v_dim) = (num_seqs, local_num_v, head_v_dim) already
|
||||||
|
|
||||||
z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
|
z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
|
||||||
core_out = core_out.reshape(num_seqs, local_num_v, self.head_v_dim)
|
|
||||||
normed = self.norm(
|
normed = self.norm(
|
||||||
core_out.reshape(-1, self.head_v_dim),
|
core_out.reshape(-1, self.head_v_dim),
|
||||||
z.reshape(-1, self.head_v_dim))
|
z.reshape(-1, self.head_v_dim))
|
||||||
@@ -740,27 +767,26 @@ class Qwen3_5MoeSparseBlock(nn.Module):
|
|||||||
if T == 1:
|
if T == 1:
|
||||||
# Fast path: single token (decode).
|
# Fast path: single token (decode).
|
||||||
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
|
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
|
||||||
# gate_up: 1 large GEMM (1,H) × (H, K*2*I) → (1, K*2*I)
|
# gate_up: 1 large GEMM (1,H) × (K*2*I,H)^T → (1, K*2*I)
|
||||||
# down: 1 einsum (K,I) × (K,H,I) → (K,H)
|
# down: 1 bmm (K,H,I) @ (K,I,1) → (K,H)
|
||||||
# Total: 3 kernel launches vs previous 16 (top_k*2).
|
# Total: 3 kernel launches vs previous 16 (top_k*2).
|
||||||
eids = topk_ids[0] # (K,)
|
eids = topk_ids[0] # (K,)
|
||||||
ws = topk_weights[0].to(hidden_states.dtype) # (K,)
|
ws = topk_weights[0].to(hidden_states.dtype) # (K,)
|
||||||
w13_sel = w13[eids] # (K, 2*I, H)
|
w13_sel = w13[eids] # (K, 2*I, H)
|
||||||
w2_sel = w2[eids] # (K, H, I)
|
w2_sel = w2[eids] # (K, H, I)
|
||||||
|
|
||||||
H = hidden_states.shape[-1]
|
H = hidden_states.shape[-1]
|
||||||
K2I = w13_sel.shape[1] # K * (2*I) after reshape
|
|
||||||
|
|
||||||
gate_up = F.linear(
|
gate_up = F.linear(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w13_sel.reshape(-1, H).contiguous(), # (K*2*I, H)
|
w13_sel.reshape(-1, H), # (K*2*I, H) — contiguous after indexing
|
||||||
) # (1, K*2*I)
|
) # (1, K*2*I)
|
||||||
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
|
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
|
||||||
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
|
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
|
||||||
act = F.silu(gate) * up # (K, I)
|
act = F.silu(gate) * up # (K, I)
|
||||||
|
|
||||||
# einsum: result[k,h] = sum_i act[k,i] * w2_sel[k,h,i]
|
# bmm: (K,H,I) @ (K,I,1) → (K,H,1) → (K,H)
|
||||||
expert_out = torch.einsum('ki,khi->kh', act, w2_sel) # (K, H)
|
expert_out = torch.bmm(w2_sel, act.unsqueeze(-1)).squeeze(-1) # (K, H)
|
||||||
|
|
||||||
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
|
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
|
||||||
hidden_states.dtype) # (1, H)
|
hidden_states.dtype) # (1, H)
|
||||||
|
|||||||
Reference in New Issue
Block a user