232 lines
11 KiB
Python
232 lines
11 KiB
Python
|
|
import torch
|
||
|
|
from typing import Any, Dict, List, Optional, Tuple
|
||
|
|
|
||
|
|
from transformers import Cache, GenerationConfig
|
||
|
|
|
||
|
|
|
||
|
|
# Generation arguments that `generate` (below) rejects with a `ValueError`, whether they are passed as explicit
# keyword arguments or set to a non-default value on a custom `generation_config`. The list itself is embedded in
# that error message.
UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",  # cache-related arguments: this function always generates with a `SinkCache`
    "cache_config",
    "return_legacy_cache",
    "num_beams",  # beam search (and cousin techniques) are not supported
    "compile_config",  # SinkCache doesn't support torch.compile
    "assistant_model",  # SinkCache doesn't support speculative (assisted) decoding either
]
|
||
|
|
|
||
|
|
class SinkCache(Cache):
    """
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    This class was copied from transformers 4.52.0, with minor modifications.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        super().__init__()
        # Per-layer key/value tensors; index `i` holds the states of layer `i`
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Memoized `(rerotation_cos, rerotation_sin)` pairs, keyed by the number of new tokens in an update
        self.cos_sin_rerotation_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        # Accumulated cos/sin tables, filled by `update` on layer 0 when RoPE kwargs are provided
        self._cos_cache = None
        self._sin_cache = None

    @staticmethod
    def _rotate_half(x):
        """Splits the last dimension of `x` in two halves `(x1, x2)` and returns `(-x2, x1)` concatenated."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Applies the rotation described by `cos`/`sin` to `key_states` and returns the rotated tensor."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the `(cos, sin)` pair used to re-rotate the kept keys when the window shifts.

        The result only depends on the number of new tokens (`key_states.shape[-2]`), so it is computed once per
        distinct value and memoized in `self.cos_sin_rerotation_cache`. `cos` and `sin` are indexed along their
        first dimension; the returned tensors are cast to `key_states.dtype` and get a leading dimension of size 1.
        """
        num_new_tokens = key_states.shape[-2]
        if num_new_tokens not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to `num_new_tokens` positions earlier
            # in the sequence. The sink positions (first `num_sink_tokens` rows) are excluded from both slices.
            original_cos = cos[self.num_sink_tokens + num_new_tokens :]
            shifted_cos = cos[self.num_sink_tokens : -num_new_tokens]
            original_sin = sin[self.num_sink_tokens + num_new_tokens :]
            shifted_sin = sin[self.num_sink_tokens : -num_new_tokens]
            # Angle-difference identities: the cos/sin of (shifted angle - original angle)
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[num_new_tokens] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[num_new_tokens]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states. A layer index can be optionally passed; `None` is treated
        as layer 0. Returns 0 when the requested layer has no cached states yet.
        """
        # The signature advertises `Optional[int]`, but comparing an int against `None` raises a `TypeError`
        if layer_idx is None:
            layer_idx = 0
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Three cases are handled: the layer has no cached states yet (the new states are stored as-is, whatever their
        length), the cache can still grow within the window (the new states are appended), or the window is full
        (the sink tokens and the most recent tokens are kept, and the new states are appended).

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        if cache_kwargs is None:
            cache_kwargs = {}
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the sin/cos cache, which holds sin/cos values for all possible positions. This is done once per
        # forward pass (on layer 0 only), the other layers reuse the stored tables.
        if using_rope and layer_idx == 0:
            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
            # after all RoPE models have a llama-like cache utilization.
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    # Keep accumulating until the tables cover the whole window
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache: keep the most recent tokens that still fit next to the sink tokens and the new tokens
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    # Only the first `partial_rotation_size` features are rotated, the rest passes through untouched
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            # Values are only sliced and concatenated; no rotation is applied to them
            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||
|
|
|
||
|
|
|
||
|
|
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
    """Custom generate function for SinkCache.

    Validates the generation arguments, builds a `SinkCache` when none is passed, and delegates to
    `model.generate` with caching forced on.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        window_length (`int`, *optional*, defaults to 256):
            The length of the context window.
        num_sink_tokens (`int`, *optional*, defaults to 4):
            The number of sink tokens. See the original paper for more information.
        **kwargs:
            Forwarded to `model.generate`. `past_key_values`, if given, must be a `SinkCache` (in which case
            `window_length` and `num_sink_tokens` are not used). `custom_generate` is dropped.

    Returns:
        The output of `model.generate`.

    Raises:
        ValueError: If an argument listed in `UNSUPPORTED_GENERATION_ARGS` is set, if the model is an
            encoder-decoder model, if `use_cache` is explicitly disabled, or if `past_key_values` is not a
            `SinkCache`.
    """
    # 1. General sanity checks
    # 1.a. A few arguments are not allowed, especially arguments that control caches.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = False
        if generation_config is not None:
            # An unsupported argument is not necessarily an attribute of a generation config (some may only exist
            # as `generate` keyword arguments), so look it up with a `None` fallback rather than letting `getattr`
            # raise an `AttributeError`.
            custom_value = getattr(generation_config, arg, None)
            # = not (match model-specific default or match global default)
            has_custom_gen_config_arg = not (
                getattr(default_model_generation_config, arg, None) == custom_value
                or getattr(default_global_generation_config, arg, None) == custom_value
            )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
    kwargs.pop("custom_generate", None)

    # 1.d. `use_cache=True` is always passed to `model.generate` below. Remove a caller-supplied value so the keyword
    # is not passed twice (which would raise a `TypeError`), and reject an explicit request to disable caching.
    use_cache = kwargs.pop("use_cache", None)
    if use_cache is not None and not use_cache:
        raise ValueError("`use_cache=False` is not supported in this custom generate function, it requires a cache")

    # 2. Generate with SinkCache
    # 2.a. prepare the cache, if it was not passed.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
    elif not isinstance(past_key_values, SinkCache):
        raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")

    # 2.b. generate with the cache
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs
|