Urgent model support: support gemma-3-it (#4424)
This commit is contained in:
@@ -1173,6 +1173,37 @@ def get_rope(
|
||||
return rotary_emb
|
||||
|
||||
|
||||
# Copied from transformers
|
||||
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension of ``x`` is split at index ``x.shape[-1] // 2``. The
    result is the negated second part followed by the first part,
    concatenated along the last dimension, so the output has the same shape
    as the input.

    Args:
        x: Tensor whose last dimension is to be rotated.

    Returns:
        ``torch.cat((-x2, x1), dim=-1)`` where ``x1`` and ``x2`` are the
        first and second parts of the last dimension of ``x``.
    """
    # Compute the split point once; for an odd-sized last dimension the
    # second part is one element longer than the first.
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to the query and key tensors.

    The arithmetic is carried out in float32 and each result is cast back to
    the dtype its input had on entry.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a new axis is inserted at ``unsqueeze_dim`` before
            it is multiplied with ``q`` and ``k``.
        sin: Sine table; handled the same way as ``cos``.
        unsqueeze_dim: Position at which the extra axis is inserted into
            ``cos`` and ``sin`` so they can broadcast against ``q`` and ``k``.
            NOTE(review): the right value depends on the layout of ``q``/``k``
            used by the caller -- confirm at call sites.

    Returns:
        A ``(q_embed, k_embed)`` tuple with the rotated query and key, in the
        original dtypes of ``q`` and ``k`` respectively.
    """
    # Remember the incoming dtypes so the outputs can be cast back.
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()

    # embedding is performed in float
    cos = cos.unsqueeze(unsqueeze_dim).float()
    sin = sin.unsqueeze(unsqueeze_dim).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)

    return q_embed, k_embed
|
||||
|
||||
|
||||
def get_rope_cpu(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
|
||||
Reference in New Issue
Block a user