add qwen3_moe

This commit is contained in:
Chranos
2026-02-10 18:09:58 +08:00
parent 8a613d15bd
commit a26729bf7f

View File

@@ -244,15 +244,19 @@ class Qwen3MoeAttention(nn.Module):
dim=-1)
# Qwen3 specific: Apply QK normalization before rotary embedding
# Use .contiguous() to ensure memory layout is compatible with
# MLU's RMSNorm which uses .view() internally.
q_shape = q.shape
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
self.head_dim).contiguous()
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
q = q_by_head.reshape(q_shape)
k_shape = k.shape
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
self.head_dim).contiguous()
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
k = k_by_head.reshape(k_shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)