初始化项目,由ModelHub XC社区提供模型
Model: transformers-community/sink_cache Source: Original Platform
This commit is contained in:
231
custom_generate/generate.py
Normal file
231
custom_generate/generate.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import torch
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import Cache, GenerationConfig
|
||||
|
||||
|
||||
UNSUPPORTED_GENERATION_ARGS = [
|
||||
"cache_implementation", # cache-related arguments, here we always use SinkCache
|
||||
"cache_config",
|
||||
"return_legacy_cache",
|
||||
"num_beams", # beam search (and cousin techniques) are not supported
|
||||
"compile_config", # SinkCache doesn't support torch.compile
|
||||
"assistant_model", # it also doesn't support speculative decoding
|
||||
]
|
||||
|
||||
class SinkCache(Cache):
|
||||
"""
|
||||
A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
|
||||
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
|
||||
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
|
||||
|
||||
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
||||
`[batch_size, num_heads, seq_len, head_dim]`.
|
||||
|
||||
This class was copied from transformers 4.52.0, with minor modifications.
|
||||
|
||||
Parameters:
|
||||
window_length (`int`):
|
||||
The length of the context window.
|
||||
num_sink_tokens (`int`):
|
||||
The number of sink tokens. See the original paper for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
||||
super().__init__()
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
self.window_length = window_length
|
||||
self.num_sink_tokens = num_sink_tokens
|
||||
self.cos_sin_rerotation_cache = {}
|
||||
self._cos_cache = None
|
||||
self._sin_cache = None
|
||||
|
||||
@staticmethod
|
||||
def _rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_key_rotary_pos_emb(
|
||||
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
|
||||
return rotated_key_states
|
||||
|
||||
def _get_rerotation_cos_sin(
|
||||
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
|
||||
# Upcast to float32 temporarily for better accuracy
|
||||
cos = cos.to(torch.float32)
|
||||
sin = sin.to(torch.float32)
|
||||
|
||||
# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
|
||||
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
|
||||
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
|
||||
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
|
||||
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
|
||||
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
|
||||
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
|
||||
|
||||
self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
|
||||
rerotation_cos.to(key_states.dtype).unsqueeze(0),
|
||||
rerotation_sin.to(key_states.dtype).unsqueeze(0),
|
||||
)
|
||||
return self.cos_sin_rerotation_cache[key_states.shape[-2]]
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
return 0
|
||||
return self.key_cache[layer_idx].shape[-2]
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
"""Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
|
||||
return self.window_length
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
||||
|
||||
Parameters:
|
||||
key_states (`torch.Tensor`):
|
||||
The new key states to cache.
|
||||
value_states (`torch.Tensor`):
|
||||
The new value states to cache.
|
||||
layer_idx (`int`):
|
||||
The index of the layer to cache the states for.
|
||||
cache_kwargs (`Dict[str, Any]`, `optional`):
|
||||
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
|
||||
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
|
||||
rotation as the tokens are shifted.
|
||||
|
||||
Return:
|
||||
A tuple containing the updated key and value states.
|
||||
"""
|
||||
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
|
||||
# with partially rotated position embeddings, like Phi or Persimmon.
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
sin = cache_kwargs.get("sin")
|
||||
cos = cache_kwargs.get("cos")
|
||||
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
|
||||
using_rope = cos is not None and sin is not None
|
||||
|
||||
# Update the sin/cos cache, which holds sin/cos values for all possible positions
|
||||
if using_rope and layer_idx == 0:
|
||||
# BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
|
||||
# after all RoPE models have a llama-like cache utilization.
|
||||
if cos.dim() == 2:
|
||||
self._cos_cache = cos
|
||||
self._sin_cache = sin
|
||||
else:
|
||||
if self._cos_cache is None:
|
||||
self._cos_cache = cos[0, ...]
|
||||
self._sin_cache = sin[0, ...]
|
||||
elif self._cos_cache.shape[0] < self.window_length:
|
||||
self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
|
||||
self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
|
||||
|
||||
# [bsz, num_heads, seq_len, head_dim]
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
# Empty cache
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
|
||||
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
|
||||
# Growing cache
|
||||
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
||||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
||||
|
||||
else:
|
||||
# Shifting cache
|
||||
keys_to_keep = self.key_cache[layer_idx][
|
||||
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
|
||||
]
|
||||
|
||||
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
|
||||
if using_rope:
|
||||
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
|
||||
key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
|
||||
)
|
||||
if partial_rotation_size is not None:
|
||||
keys_to_keep, keys_pass = (
|
||||
keys_to_keep[..., :partial_rotation_size],
|
||||
keys_to_keep[..., partial_rotation_size:],
|
||||
)
|
||||
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
|
||||
if partial_rotation_size is not None:
|
||||
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
|
||||
|
||||
# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
|
||||
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
|
||||
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
|
||||
|
||||
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
|
||||
values_to_keep = self.value_cache[layer_idx][
|
||||
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
|
||||
]
|
||||
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
|
||||
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
|
||||
|
||||
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
||||
"""Custom generate function for SinkCache.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModel`):
|
||||
The model to generate from.
|
||||
window_length (`int`, *optional*, defaults to 256):
|
||||
The length of the context window.
|
||||
num_sink_tokens (`int`, *optional*, defaults to 4):
|
||||
The number of sink tokens. See the original paper for more information.
|
||||
"""
|
||||
# 1. General sanity checks
|
||||
# 1.a. A few arguments are not allowed, especially arguments that control caches.
|
||||
generation_config = kwargs.get("generation_config")
|
||||
default_global_generation_config = GenerationConfig()
|
||||
default_model_generation_config = model.generation_config
|
||||
for arg in UNSUPPORTED_GENERATION_ARGS:
|
||||
has_custom_gen_config_arg = (
|
||||
generation_config is not None
|
||||
# = and not (match global default or match model-specific default)
|
||||
and not (
|
||||
getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
|
||||
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
|
||||
)
|
||||
)
|
||||
kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
|
||||
if kwargs_has_arg or has_custom_gen_config_arg:
|
||||
raise ValueError(
|
||||
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
|
||||
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
|
||||
)
|
||||
|
||||
# 1.b. The model must be decoder-only
|
||||
if model.config.is_encoder_decoder:
|
||||
raise ValueError("This custom generate function only works with decoder-only models")
|
||||
|
||||
# 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
|
||||
# in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
|
||||
kwargs.pop("custom_generate", None)
|
||||
|
||||
# 2. Generate with SinkCache
|
||||
# 2.a. prepare the cache, if it was not passed.
|
||||
past_key_values = kwargs.pop("past_key_values", None)
|
||||
if past_key_values is None:
|
||||
past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
|
||||
elif not isinstance(past_key_values, SinkCache):
|
||||
raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
|
||||
|
||||
# 2.b. generate with the cache
|
||||
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
||||
return generation_outputs
|
||||
Reference in New Issue
Block a user