init
This commit is contained in:
531
vllm/model_executor/layers/rotary_embedding.py
Normal file
531
vllm/model_executor/layers/rotary_embedding.py
Normal file
@@ -0,0 +1,531 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rotary Positional Embeddings."""
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return x.flatten(-2)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(torch.get_default_dtype())
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
|
||||
# However, we use `torch.arange(..., dtype=torch.float)` instead to
|
||||
# avoid numerical issues with large base values (e.g., 10000000).
|
||||
# This may cause a slight numerical difference between the HF
|
||||
# implementation and ours.
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
# use CPU to compute the cache and then move it to GPU. However, we
|
||||
# create the cache on GPU for faster initialization. This may cause
|
||||
# a slight numerical difference between the HF implementation and ours.
|
||||
|
||||
# torch_musa did not support pow_scalar_out
|
||||
# inv_freq = 1.0 / (base**(torch.arange(
|
||||
# 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||
|
||||
exp = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
|
||||
device = exp.device
|
||||
inv_freq = 1.0 / (base**(exp.cpu() / self.rotary_dim))
|
||||
return inv_freq.to(device)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size)
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size)
|
||||
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
if self.rotary_dim < self.head_size:
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
|
||||
positions.device)
|
||||
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
||||
if offsets is not None else positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
if self.is_neox_style:
|
||||
# NOTE(woosuk): Here we assume that the positions tensor has the
|
||||
# shape [batch_size, seq_len].
|
||||
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
|
||||
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
|
||||
else:
|
||||
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
||||
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
||||
|
||||
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
|
||||
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
|
||||
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
||||
|
||||
if self.rotary_dim < self.head_size:
|
||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||
else:
|
||||
query = query_rot
|
||||
key = key_rot
|
||||
query = query.flatten(-2)
|
||||
key = key.flatten(-2)
|
||||
return query, key
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style, self.rotary_dim,
|
||||
offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
||||
s += f", max_position_embeddings={self.max_position_embeddings}"
|
||||
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
|
||||
return s
|
||||
|
||||
|
||||
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||
"""RotaryEmbedding extended with linear scaling.
|
||||
|
||||
Credits to the Reddit user /u/kaiokendev
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factors: Union[List[float], float],
|
||||
) -> None:
|
||||
if isinstance(scaling_factors, float):
|
||||
scaling_factors = [scaling_factors]
|
||||
self.scaling_factors = scaling_factors
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
cache_list = []
|
||||
for scaling_factor in self.scaling_factors:
|
||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||
# maximum length before applying the rope scaling.
|
||||
# Thus, the maximum length after applying the rope scaling is
|
||||
# self.max_position_embeddings * self.scaling_factor.
|
||||
max_len = self.max_position_embeddings * scaling_factor
|
||||
t = torch.arange(max_len, dtype=torch.float)
|
||||
t = t / scaling_factor
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
cache_list.append(cache)
|
||||
return torch.cat(cache_list, dim=0)
|
||||
|
||||
|
||||
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||
"""RotaryEmbedding extended with Dynamic NTK scaling.
|
||||
|
||||
Credits to the Reddit users /u/bloc97 and /u/emozilla
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||
# maximum length before applying the rope scaling.
|
||||
# Thus, the maximum length after applying the rope scaling is
|
||||
# self.max_position_embeddings * self.scaling_factor.
|
||||
max_len = self.max_position_embeddings * self.scaling_factor
|
||||
base = self.base * (
|
||||
(self.scaling_factor * max_len / self.max_position_embeddings) -
|
||||
(self.scaling_factor - 1))**(self.rotary_dim /
|
||||
(self.rotary_dim - 2))
|
||||
inv_freq = self._compute_inv_freq(base)
|
||||
t = torch.arange(max_len, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def _yarn_find_correction_dim(num_rotations: int,
|
||||
dim: int,
|
||||
base: float = 10000,
|
||||
max_position_embeddings: int = 2048) -> float:
|
||||
return (dim * math.log(max_position_embeddings /
|
||||
(num_rotations * 2 * math.pi))) / (2 *
|
||||
math.log(base))
|
||||
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def _yarn_find_correction_range(
|
||||
low_rot: int,
|
||||
high_rot: int,
|
||||
dim: int,
|
||||
base: float = 10000,
|
||||
max_position_embeddings: int = 2048) -> Tuple[int, int]:
|
||||
low = math.floor(
|
||||
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = math.ceil(
|
||||
_yarn_find_correction_dim(high_rot, dim, base,
|
||||
max_position_embeddings))
|
||||
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||
|
||||
|
||||
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
|
||||
dtype: torch.dtype) -> torch.Tensor:
|
||||
if low == high:
|
||||
high += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
def _yarn_get_mscale(scale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
|
||||
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
||||
"""RotaryEmbedding extended with YaRN method.
|
||||
|
||||
Credits to Peng et al. github.com/jquesnelle/yarn
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation
|
||||
self.mscale = float(
|
||||
_yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
pos_freqs = self.base**(
|
||||
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
|
||||
self.rotary_dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
||||
|
||||
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
|
||||
self.rotary_dim, self.base,
|
||||
self.max_position_embeddings)
|
||||
# Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
|
||||
low, high, self.rotary_dim // 2,
|
||||
dtype=torch.float)) * self.extrapolation_factor
|
||||
inv_freq = inv_freq_interpolation * (
|
||||
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = (freqs.cos() * self.mscale)
|
||||
sin = (freqs.sin() * self.mscale)
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
|
||||
class Phi3SuScaledRotaryEmbedding(nn.Module):
|
||||
"""Phi3 family of models scaled rotary embedding.
|
||||
|
||||
Based on the original RotaryEmbedding implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
original_max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
short_factor: List[float],
|
||||
long_factor: List[float],
|
||||
short_mscale: float = 1.1,
|
||||
long_mscale: float = 1.225,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if rotary_dim != head_size:
|
||||
raise ValueError(
|
||||
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
|
||||
head_size ({rotary_dim}!={head_size}).")
|
||||
if is_neox_style is False:
|
||||
raise ValueError(
|
||||
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
|
||||
|
||||
self.head_size = head_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.base = base
|
||||
self.short_factor = short_factor
|
||||
self.long_factor = long_factor
|
||||
self.short_mscale = short_mscale
|
||||
self.long_mscale = long_mscale
|
||||
|
||||
short_cache = self._compute_cos_sin_cache(
|
||||
original_max_position_embeddings, short_factor, short_mscale)
|
||||
short_cache = short_cache.to(torch.get_default_dtype())
|
||||
self.register_buffer("short_cos_sin_cache",
|
||||
short_cache,
|
||||
persistent=False)
|
||||
|
||||
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
|
||||
long_factor, long_mscale)
|
||||
long_cache = long_cache.to(torch.get_default_dtype())
|
||||
self.register_buffer("long_cos_sin_cache",
|
||||
long_cache,
|
||||
persistent=False)
|
||||
|
||||
long_short_cache = torch.cat(
|
||||
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
|
||||
self.register_buffer("long_short_cos_sin_cache",
|
||||
long_short_cache,
|
||||
persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
|
||||
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
|
||||
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
|
||||
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(
|
||||
self,
|
||||
max_position_embeddings: int,
|
||||
rescale_factors: List[float],
|
||||
mscale: float,
|
||||
) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(rescale_factors)
|
||||
t = torch.arange(max_position_embeddings, dtype=torch.float)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos() * mscale
|
||||
sin = freqs.sin() * mscale
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size)
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size)
|
||||
|
||||
k = self.original_max_position_embeddings
|
||||
long_prompt_offset = (torch.any(positions > k).float() *
|
||||
torch.full_like(positions, k)).long()
|
||||
idx = (torch.add(positions, long_prompt_offset)
|
||||
if long_prompt_offset is not None else positions)
|
||||
self.long_short_cos_sin_cache: torch.Tensor = (
|
||||
self.long_short_cos_sin_cache.to(idx.device))
|
||||
idx = torch.add(idx, offsets) if offsets is not None else idx
|
||||
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
|
||||
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
cos = cos.repeat(1, 2).unsqueeze(-2)
|
||||
sin = sin.repeat(1, 2).unsqueeze(-2)
|
||||
|
||||
query = query * cos + _rotate_neox(query) * sin
|
||||
key = key * cos + _rotate_neox(key) * sin
|
||||
|
||||
return query.flatten(-2), key.flatten(-2)
|
||||
|
||||
|
||||
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
||||
|
||||
|
||||
def get_rope(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position: int,
|
||||
base: int,
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
) -> RotaryEmbedding:
|
||||
if rope_scaling is not None:
|
||||
# Transforms every value that is a list into a tuple for caching calls
|
||||
rope_scaling_tuple = {
|
||||
k: tuple(v) if isinstance(v, list) else v
|
||||
for k, v in rope_scaling.items()
|
||||
}
|
||||
rope_scaling_args = tuple(rope_scaling_tuple.items())
|
||||
else:
|
||||
rope_scaling_args = None
|
||||
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
rope_scaling_args)
|
||||
if key in _ROPE_DICT:
|
||||
return _ROPE_DICT[key]
|
||||
if rope_scaling is None:
|
||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style)
|
||||
else:
|
||||
scaling_type = rope_scaling["type"]
|
||||
if scaling_type != "su":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
max_position, base,
|
||||
is_neox_style,
|
||||
scaling_factor)
|
||||
elif scaling_type == "dynamic":
|
||||
rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
scaling_factor)
|
||||
elif scaling_type == "yarn":
|
||||
original_max_position = rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
|
||||
"beta_slow")
|
||||
}
|
||||
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
original_max_position,
|
||||
base, is_neox_style,
|
||||
scaling_factor,
|
||||
**extra_kwargs)
|
||||
elif scaling_type == "su":
|
||||
short_factor = rope_scaling["short_factor"]
|
||||
long_factor = rope_scaling["long_factor"]
|
||||
original_max_position = rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
if k in ("short_mscale", "long_mscale")
|
||||
}
|
||||
rotary_emb = Phi3SuScaledRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, original_max_position,
|
||||
base, is_neox_style, short_factor, long_factor, **extra_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
_ROPE_DICT[key] = rotary_emb
|
||||
return rotary_emb
|
||||
Reference in New Issue
Block a user