forked from EngineX-Cambricon/enginex-mlu370-vllm
add gemma3
@@ -237,7 +237,13 @@ class Gemma3Attention(nn.Module):
         k = self.k_norm(k)
         k = k.flatten(-2, -1)
 
-        q, k = self.rotary_emb(positions, q, k)
+        # MLU rotary_emb expects a single concatenated tensor, not
+        # separate q and k (forward_mlu signature differs from forward_native).
+        qk = torch.cat([q, k], dim=-1)
+        self.rotary_emb(positions,
+                        qk.view(-1, self.num_heads + self.num_kv_heads,
+                                self.head_dim))
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
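For context, the pattern the added lines implement can be sketched outside of vLLM: concatenate q and k along the last dimension, hand the kernel a per-head view of that shared buffer so it can rotate it in place, then split q and k back out. The sketch below is not the repository's code; apply_rotary_inplace is a hypothetical stand-in for the MLU rotary_emb kernel, and the head counts, head_dim, and token count are made-up values.

# Minimal sketch, assuming an in-place rotary kernel like the MLU path implies.
import torch

num_heads, num_kv_heads, head_dim = 8, 4, 128
q_size = num_heads * head_dim        # 1024
kv_size = num_kv_heads * head_dim    # 512
num_tokens = 16

q = torch.randn(num_tokens, q_size)
k = torch.randn(num_tokens, kv_size)


def apply_rotary_inplace(x: torch.Tensor) -> None:
    # Hypothetical stand-in: rotate each (token, head) vector in place on a
    # (tokens, heads, head_dim) view. The real kernel uses per-position
    # cos/sin tables; a fixed 90-degree "rotate-half" keeps the demo short.
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
    x.copy_(torch.cat((-x2, x1), dim=-1))


# The pattern from the diff: concatenate q and k along the last dim, expose a
# per-head view so the kernel can mutate the shared buffer, then split back.
qk = torch.cat([q, k], dim=-1)                 # (tokens, q_size + kv_size)
apply_rotary_inplace(qk.view(-1, num_heads + num_kv_heads, head_dim))
q, k = qk.split([q_size, kv_size], dim=-1)     # views into the rotated buffer

print(q.shape, k.shape)  # torch.Size([16, 1024]) torch.Size([16, 512])

Because the kernel mutates its input, the return value of the rotary call can be ignored, and the split q and k alias the rotated buffer, which avoids an extra copy compared with returning fresh tensors.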